{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Memory Augmented Neural Network using Omniglot Dataset\n",
    "\n",
    "\n",
    "In this tutorial, we will do following things step by step:\n",
    "1. Data Preprocessing: Creating Pairs.\n",
    "2. Create a Memory Augmented Enural Network\n",
    "3. Train it using Omniglot dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: scipy==1.1.0 in /Users/sjadon/.local/lib/python3.6/site-packages\n",
      "Requirement already satisfied: numpy>=1.8.2 in /opt/anaconda3/envs/project09/lib/python3.6/site-packages (from scipy==1.1.0)\n",
      "\u001b[33mYou are using pip version 9.0.1, however version 20.0.2 is available.\n",
      "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n",
      "Requirement already satisfied: Pillow in /opt/anaconda3/envs/project09/lib/python3.6/site-packages\n",
      "\u001b[33mYou are using pip version 9.0.1, however version 20.0.2 is available.\n",
      "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install scipy==1.1.0\n",
    "!pip install Pillow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Step1: Lets first import all libraries needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "cannot import name 'imresize'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-15-c5b50f98ed59>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpandas\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mut\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/github/Hands-on-One-Shot-Learning/Ch04-ModelsBasedMethods/utils.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mscipy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmisc\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mimresize\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mimread\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mos\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mlistdir\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msplitext\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mrandom\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mseed\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mImportError\u001b[0m: cannot import name 'imresize'"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import utils as ut\n",
    "import os\n",
    "import time\n",
    "from scipy.misc import imresize\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from torch.nn import functional as F\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step2 : Load Data\n",
    "\n",
    "We are reading images from two folders named 'image_background', 'image_evaluation' defined in 'data' directory\n",
    "\n",
    "Dataset is divided into 1423 charcters images for training and rest for validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "width = 20\n",
    "\n",
    "# gather data paths\n",
    "subfolds = ut.extend_children('data','')\n",
    "datafolds = [subfolds[0],subfolds[1]]\n",
    "alphabets = ut.extend_generation(datafolds,'')\n",
    "charpaths = ut.extend_generation(alphabets,'')\n",
    "chars_dataset = [v.split('/')[2]+'/'+v.split('/')[3] for v in charpaths]\n",
    "\n",
    "# index-value conversion dictionaries for character set\n",
    "i2v = {i:v for i, v in enumerate(chars_dataset)}\n",
    "v2i = {v:i for i, v in enumerate(chars_dataset)}\n",
    "\n",
    "# get size of dataset\n",
    "mc_dataset = len(charpaths)\n",
    "print(mc_dataset,'total character classes')\n",
    "\n",
    "# train split\n",
    "mc_train = 1423\n",
    "chars_train = chars_dataset[:mc_train]\n",
    "classes_train = [v2i[v] for v in chars_train]\n",
    "\n",
    "# validation split\n",
    "mc_val = mc_dataset-mc_train\n",
    "chars_val = chars_dataset[-mc_val:]\n",
    "classes_val = [v2i[v] for v in chars_val]\n",
    "\n",
    "\n",
    "\n",
    "print('%s characters assigned for training, %s characters assigned for validation'%(mc_train,mc_val))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Read Data\n",
    "\n",
    "All character images are read and saved in the imgs_dataset variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'charpaths' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-eb9abe61d6e6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# load images from character paths\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mimgs_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcharfold\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcharpaths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m%\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'%s/%s character folders loaded'\u001b[0m\u001b[0;34m%\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmc_dataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'charpaths' is not defined"
     ]
    }
   ],
   "source": [
    "# load images from character paths\n",
    "imgs_dataset = []\n",
    "for i, charfold in enumerate(charpaths):\n",
    "    if i%200==0:\n",
    "        print('%s/%s character folders loaded'%(i,mc_dataset))\n",
    "    imgs_dataset.append([ ut.load_image(imgpath,(width,width))/127.5-1 for imgpath in ut.extend_children(charfold,'.png') ] )\n",
    "# access imgs_dataset [ character index ] [ sample index ] [ row,col ]\n",
    "\n",
    "# split images between train and validation sets\n",
    "imgs_train = imgs_dataset[:mc_train]\n",
    "imgs_val = imgs_dataset[-mc_val:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 4: Let's Initialize our Hyper-paramters\n",
    "\n",
    "We define all the hyper-parameters and dimensions of all the variables here.\n",
    "* n_classes is 5 (in each batch we have 5 types of characters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "n_classes = 5\n",
    "memory_size = 128 # number of features per entry\n",
    "memory_dim = 40 # number of entries in memory\n",
    "learning_rate = 1e-3\n",
    "batch_size = 16\n",
    "\n",
    "n_inputs = width*width+n_classes # input to LSTM cell.\n",
    "n_hnodes = 200 # LSTM cell size\n",
    "n_outputs = n_classes\n",
    "mem_size = memory_size # number of rows in external memory \n",
    "mem_dim = memory_dim # number of columns in external memory\n",
    "n_reads = 4 # number of heads. \n",
    "\n",
    "n_xh = n_inputs+n_hnodes # inputs to LSTM cell (previous state + input image dim)\n",
    "n_rd = n_reads*mem_dim\n",
    "n_hr = n_hnodes+n_rd\n",
    "gamma = 0.95"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 5: Let's Create all Trainable parameters for network and initialize them.\n",
    "\n",
    "We define all the hyper-parameters and dimensions of all the variables here.\n",
    "\n",
    "This defines three types of weights:\n",
    "* LSTM weights (we use LSTM controller, a feed-forward controller could also be used)\n",
    "* Read and Write key weights\n",
    "* Output fully connected layer weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# LSTM gates (4 of them: weights and biases for each)\n",
    "W_gf = torch.Tensor(n_xh, n_hnodes).uniform_(-1., 1.).requires_grad_()\n",
    "b_gf = torch.Tensor(n_hnodes).uniform_(-1., 1.).requires_grad_()\n",
    "W_gi = torch.Tensor(n_xh,n_hnodes).uniform_(-1., 1.).requires_grad_()\n",
    "b_gi = torch.Tensor(n_hnodes).uniform_(-1., 1.).requires_grad_()\n",
    "W_go = torch.Tensor(n_xh,n_hnodes).uniform_(-1., 1.).requires_grad_()\n",
    "b_go = torch.Tensor(n_hnodes).uniform_(-1., 1.).requires_grad_()\n",
    "W_u = torch.Tensor(n_xh,n_hnodes).uniform_(-1., 1.).requires_grad_()\n",
    "b_u = torch.Tensor(n_hnodes).uniform_(-1., 1.).requires_grad_()\n",
    "# Controller Weights\n",
    "W_kr = torch.Tensor(n_hnodes,n_rd).uniform_(-1., 1.).requires_grad_()\n",
    "b_kr = torch.Tensor(n_rd).uniform_(-1., 1.).requires_grad_()\n",
    "W_kw = torch.Tensor(n_hnodes,n_rd).uniform_(-1., 1.).requires_grad_()\n",
    "b_kw = torch.Tensor(n_rd).uniform_(-1., 1.).requires_grad_()\n",
    "W_ga = torch.Tensor(n_hnodes,n_reads).uniform_(-1., 1.).requires_grad_()\n",
    "b_ga = torch.Tensor(n_reads).uniform_(-1., 1.).requires_grad_()\n",
    "# logit weights\n",
    "W_o = torch.Tensor(n_hr,n_outputs).uniform_(-1., 1.).requires_grad_()\n",
    "b_o = torch.Tensor(n_outputs).uniform_(-1., 1.).requires_grad_()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 6: Define Model\n",
    "\n",
    "The next cell defines the forward pass of the network. \n",
    "It works as follows:\n",
    "* the net() function receives input X which is sliced along dim 0 which is the time dimension\n",
    "* we sequentially process each time step in run_one_step() function and collect state vectors of each step's output\n",
    "* The predictions are made at each time step with fully connected output layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_state0():\n",
    "    # memory variables (not trainable.)\n",
    "    # initialize memory and LSTM states with zero. \n",
    "    return(\n",
    "        torch.FloatTensor(1e-6*np.random.rand(batch_size,mem_size,mem_dim)),\n",
    "        torch.FloatTensor(np.zeros((batch_size,n_hnodes))),\n",
    "        torch.FloatTensor(np.zeros((batch_size,n_hnodes))),\n",
    "        torch.FloatTensor(np.zeros((batch_size,mem_size))),\n",
    "        torch.FloatTensor(np.zeros((batch_size,n_reads,mem_size))),\n",
    "        torch.FloatTensor(np.zeros((batch_size,n_reads,mem_dim))),\n",
    "    )\n",
    "\n",
    "def run_one_step(X_t, state):\n",
    "    # Run one step of the episode.\n",
    "    M_tm1, h_tm1, c_tm1, wu_tm1, wr_tm1, r_tm1 = state\n",
    "    X_t_r = X_t.view(-1,n_inputs)\n",
    "    xh = torch.cat((X_t_r,h_tm1),1)\n",
    "    gf = torch.sigmoid(torch.matmul(xh,W_gf) + b_gf)\n",
    "    gi = torch.sigmoid(torch.matmul(xh,W_gi) + b_gi)\n",
    "    go = torch.sigmoid(torch.matmul(xh,W_go) + b_go)\n",
    "    u_t = torch.tanh(torch.matmul(xh,W_u) + b_u)\n",
    "    c_t = c_tm1*gf + u_t*gi\n",
    "    h_t = c_t*go\n",
    "    kr_t = torch.tanh(torch.matmul(c_t,W_kr) + b_kr).view(batch_size,n_reads,mem_dim)\n",
    "    kw_t = torch.tanh(torch.matmul(c_t,W_kw) + b_kw).view(batch_size,n_reads,mem_dim)\n",
    "    k_norm = torch.norm(kr_t, dim=2, keepdim=True)\n",
    "    m_norm = torch.norm(M_tm1, dim=2, keepdim=True)\n",
    "    inner_prod = torch.matmul(kr_t, M_tm1.permute(0,2,1))\n",
    "    norm_prod = torch.matmul(k_norm, m_norm.permute(0,2,1))\n",
    "    wr_t = F.softmax(inner_prod/norm_prod)\n",
    "    wu_1 = wu_tm1*gamma + torch.sum(wr_t, dim=1)\n",
    "    r_t = torch.matmul(wr_t,M_tm1)\n",
    "    ga = torch.unsqueeze(torch.sigmoid(torch.matmul(h_t,W_ga)+b_ga),2)\n",
    "    _, wlu_inds = torch.topk(-1*wu_1,k=n_reads)\n",
    "    wlu_t = torch.sum(F.one_hot(wlu_inds, mem_size).type(torch.FloatTensor),dim=1,keepdim=True)\n",
    "    ww_t = wr_t*ga + wlu_t*(1-ga)\n",
    "    wu_t = wu_1 + torch.sum(ww_t, dim=1)\n",
    "    M_1 = M_tm1 * (-1*wlu_t).permute(0,2,1)\n",
    "    M_t = M_1 + torch.matmul(ww_t.permute(0,2,1), kw_t)\n",
    "    st8_t = (M_t, h_t, c_t, wu_t, wr_t, r_t)\n",
    "    return st8_t\n",
    "    \n",
    "\n",
    "    \n",
    "def net(X=None, y=None):\n",
    "    # X is of shape (batch_size, None, width, width)\n",
    "    a = np.arange(100*16*405).reshape((100,16,405)).astype(np.float32)\n",
    "    X = torch.from_numpy(a)\n",
    "    \n",
    "    state0 = get_state0()\n",
    "    curr_state = state0\n",
    "    \n",
    "    # Collect output of state vectors in each time step.\n",
    "    M_f = []\n",
    "    h_f = []\n",
    "    c_f = []\n",
    "    wu_f = []\n",
    "    wr_f = []\n",
    "    r_f = []\n",
    "    \n",
    "    for i in range(X.shape[0]):\n",
    "        curr_state = run_one_step(X[i], curr_state)\n",
    "        M_f.append(curr_state[0])\n",
    "        h_f.append(curr_state[1])\n",
    "        c_f.append(curr_state[2])\n",
    "        wu_f.append(curr_state[3])\n",
    "        wr_f.append(curr_state[4])\n",
    "        r_f.append(curr_state[5])\n",
    "        \n",
    "    M_f = torch.stack(M_f)\n",
    "    h_f = torch.stack(h_f)\n",
    "    c_f = torch.stack(c_f)\n",
    "    wu_f = torch.stack(wu_f)\n",
    "    wr_f = torch.stack(wr_f)\n",
    "    r_f = torch.stack(r_f)\n",
    "\n",
    "    hr = torch.cat((h_f, r_f.view(-1,batch_size,n_rd)),2)\n",
    "    o_f = torch.tensordot(hr,W_o,1)+b_o\n",
    "    return (M_f, h_f, c_f, wu_f, wr_f, r_f, o_f)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Step 7: Initialize optimizer and criteria for  training\n",
    "\n",
    "We use Adam optimizer on cross entropy loss "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "optimizer = optim.Adam([W_gf, b_gf, W_gi, b_gi, W_go, b_go, W_u, b_u, W_kr,b_kr, W_kw, b_kw,\n",
    "                      W_ga, b_ga, W_o, b_o], lr=learning_rate)\n",
    "    \n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare graph input\n",
    "\n",
    "Y-labels predicted at each time step are appended to input at next time step for better signal.\n",
    "* During training, we have access to y_labels at time step. Shift the y_labels by one time step and append this to the input to prepare input data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_graph_input(X_train, y_train):\n",
    "    X = np.transpose(X_train.reshape(batch_size,-1,width*width),(1,0,2))\n",
    "    y_labels = np.transpose(y_train, (1,0,2))\n",
    "    y_labels_shifted = np.concatenate((np.zeros((1,batch_size,n_classes)), y_labels[:-1,:]),0)\n",
    "    X = np.concatenate((X,y_labels_shifted),-1)\n",
    "    y = np.argmax(y_labels, -1)\n",
    "    return torch.from_numpy(X), torch.from_numpy(y)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's Run Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 starting..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sjadon/anaconda3/envs/mysite/lib/python3.6/site-packages/ipykernel_launcher.py:30: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 / Batch (1/17) / Loss 182.969574\n",
      "Epoch 1 / Batch (2/17) / Loss 178.152054\n",
      "Epoch 1 / Batch (3/17) / Loss 177.749924\n",
      "Epoch 1 / Batch (4/17) / Loss 172.039276\n",
      "Epoch 1 / Batch (5/17) / Loss 163.774048\n",
      "Epoch 1 / Batch (6/17) / Loss 162.877548\n",
      "Epoch 1 / Batch (7/17) / Loss 162.991608\n",
      "Epoch 1 / Batch (8/17) / Loss 154.819504\n",
      "Epoch 1 / Batch (9/17) / Loss 153.378937\n",
      "Epoch 1 / Batch (10/17) / Loss 153.685257\n",
      "Epoch 1 / Batch (11/17) / Loss 154.252106\n",
      "Epoch 1 / Batch (12/17) / Loss 146.768921\n",
      "Epoch 1 / Batch (13/17) / Loss 135.351486\n",
      "Epoch 1 / Batch (14/17) / Loss 137.366592\n",
      "Epoch 1 / Batch (15/17) / Loss 135.110306\n",
      "Epoch 1 / Batch (16/17) / Loss 130.907349\n",
      "Epoch 1 / Batch (17/17) / Loss 131.145569\n",
      "Epoch 1 complete, 8.947108745574951 seconds elapsed\n",
      "Epoch 2 starting..\n",
      "Epoch 2 / Batch (1/17) / Loss 127.797226\n",
      "Epoch 2 / Batch (2/17) / Loss 124.994247\n",
      "Epoch 2 / Batch (3/17) / Loss 128.439606\n",
      "Epoch 2 / Batch (4/17) / Loss 124.740242\n",
      "Epoch 2 / Batch (5/17) / Loss 122.984512\n",
      "Epoch 2 / Batch (6/17) / Loss 122.479240\n",
      "Epoch 2 / Batch (7/17) / Loss 121.243927\n",
      "Epoch 2 / Batch (8/17) / Loss 117.775108\n",
      "Epoch 2 / Batch (9/17) / Loss 114.443611\n",
      "Epoch 2 / Batch (10/17) / Loss 116.625008\n",
      "Epoch 2 / Batch (11/17) / Loss 119.237480\n",
      "Epoch 2 / Batch (12/17) / Loss 112.099030\n",
      "Epoch 2 / Batch (13/17) / Loss 108.810661\n",
      "Epoch 2 / Batch (14/17) / Loss 110.960075\n",
      "Epoch 2 / Batch (15/17) / Loss 108.040001\n",
      "Epoch 2 / Batch (16/17) / Loss 104.902718\n",
      "Epoch 2 / Batch (17/17) / Loss 105.576149\n",
      "Epoch 2 complete, 10.407933235168457 seconds elapsed\n",
      "Epoch 3 starting..\n",
      "Epoch 3 / Batch (1/17) / Loss 100.779404\n",
      "Epoch 3 / Batch (2/17) / Loss 103.359726\n",
      "Epoch 3 / Batch (3/17) / Loss 102.713554\n",
      "Epoch 3 / Batch (4/17) / Loss 103.232315\n",
      "Epoch 3 / Batch (5/17) / Loss 102.222473\n",
      "Epoch 3 / Batch (6/17) / Loss 98.093407\n",
      "Epoch 3 / Batch (7/17) / Loss 98.872566\n",
      "Epoch 3 / Batch (8/17) / Loss 98.082657\n",
      "Epoch 3 / Batch (9/17) / Loss 94.891541\n",
      "Epoch 3 / Batch (10/17) / Loss 99.217018\n",
      "Epoch 3 / Batch (11/17) / Loss 95.829063\n",
      "Epoch 3 / Batch (12/17) / Loss 92.734657\n",
      "Epoch 3 / Batch (13/17) / Loss 93.787048\n",
      "Epoch 3 / Batch (14/17) / Loss 93.988495\n",
      "Epoch 3 / Batch (15/17) / Loss 92.050018\n",
      "Epoch 3 / Batch (16/17) / Loss 87.660271\n",
      "Epoch 3 / Batch (17/17) / Loss 90.650192\n",
      "Epoch 3 complete, 10.343433856964111 seconds elapsed\n",
      "Epoch 4 starting..\n",
      "Epoch 4 / Batch (1/17) / Loss 88.841667\n",
      "Epoch 4 / Batch (2/17) / Loss 88.310379\n",
      "Epoch 4 / Batch (3/17) / Loss 83.887497\n",
      "Epoch 4 / Batch (4/17) / Loss 84.118431\n",
      "Epoch 4 / Batch (5/17) / Loss 81.824417\n",
      "Epoch 4 / Batch (6/17) / Loss 83.589111\n",
      "Epoch 4 / Batch (7/17) / Loss 85.429962\n",
      "Epoch 4 / Batch (8/17) / Loss 82.665047\n",
      "Epoch 4 / Batch (9/17) / Loss 78.521149\n",
      "Epoch 4 / Batch (10/17) / Loss 79.834656\n",
      "Epoch 4 / Batch (11/17) / Loss 80.544601\n",
      "Epoch 4 / Batch (12/17) / Loss 76.633850\n",
      "Epoch 4 / Batch (13/17) / Loss 73.699097\n",
      "Epoch 4 / Batch (14/17) / Loss 75.144051\n",
      "Epoch 4 / Batch (15/17) / Loss 75.683395\n",
      "Epoch 4 / Batch (16/17) / Loss 75.213860\n",
      "Epoch 4 / Batch (17/17) / Loss 73.960480\n",
      "Epoch 4 complete, 10.593315839767456 seconds elapsed\n",
      "Epoch 5 starting..\n",
      "Epoch 5 / Batch (1/17) / Loss 74.243095\n",
      "Epoch 5 / Batch (2/17) / Loss 73.208855\n",
      "Epoch 5 / Batch (3/17) / Loss 71.174980\n",
      "Epoch 5 / Batch (4/17) / Loss 69.349098\n",
      "Epoch 5 / Batch (5/17) / Loss 69.872429\n",
      "Epoch 5 / Batch (6/17) / Loss 69.297806\n",
      "Epoch 5 / Batch (7/17) / Loss 65.475555\n",
      "Epoch 5 / Batch (8/17) / Loss 63.702423\n",
      "Epoch 5 / Batch (9/17) / Loss 68.804001\n",
      "Epoch 5 / Batch (10/17) / Loss 66.416946\n",
      "Epoch 5 / Batch (11/17) / Loss 62.396027\n",
      "Epoch 5 / Batch (12/17) / Loss 63.417564\n",
      "Epoch 5 / Batch (13/17) / Loss 61.261486\n",
      "Epoch 5 / Batch (14/17) / Loss 61.478539\n",
      "Epoch 5 / Batch (15/17) / Loss 64.617935\n",
      "Epoch 5 / Batch (16/17) / Loss 58.671997\n",
      "Epoch 5 / Batch (17/17) / Loss 57.954525\n",
      "Epoch 5 complete, 10.732754945755005 seconds elapsed\n",
      "Epoch 6 starting..\n",
      "Epoch 6 / Batch (1/17) / Loss 58.018036\n",
      "Epoch 6 / Batch (2/17) / Loss 59.246098\n",
      "Epoch 6 / Batch (3/17) / Loss 55.377407\n",
      "Epoch 6 / Batch (4/17) / Loss 57.655319\n",
      "Epoch 6 / Batch (5/17) / Loss 54.089630\n",
      "Epoch 6 / Batch (6/17) / Loss 53.961586\n",
      "Epoch 6 / Batch (7/17) / Loss 51.245995\n",
      "Epoch 6 / Batch (8/17) / Loss 53.143158\n",
      "Epoch 6 / Batch (9/17) / Loss 53.302555\n",
      "Epoch 6 / Batch (10/17) / Loss 49.422802\n",
      "Epoch 6 / Batch (11/17) / Loss 49.720566\n",
      "Epoch 6 / Batch (12/17) / Loss 51.025616\n",
      "Epoch 6 / Batch (13/17) / Loss 49.490211\n",
      "Epoch 6 / Batch (14/17) / Loss 48.078384\n",
      "Epoch 6 / Batch (15/17) / Loss 46.670063\n",
      "Epoch 6 / Batch (16/17) / Loss 46.331905\n",
      "Epoch 6 / Batch (17/17) / Loss 45.812626\n",
      "Epoch 6 complete, 10.956331014633179 seconds elapsed\n",
      "Epoch 7 starting..\n",
      "Epoch 7 / Batch (1/17) / Loss 45.674374\n",
      "Epoch 7 / Batch (2/17) / Loss 43.447594\n",
      "Epoch 7 / Batch (3/17) / Loss 42.427979\n",
      "Epoch 7 / Batch (4/17) / Loss 44.263512\n",
      "Epoch 7 / Batch (5/17) / Loss 40.912029\n",
      "Epoch 7 / Batch (6/17) / Loss 42.132076\n",
      "Epoch 7 / Batch (7/17) / Loss 42.468571\n",
      "Epoch 7 / Batch (8/17) / Loss 37.692814\n",
      "Epoch 7 / Batch (9/17) / Loss 41.475685\n",
      "Epoch 7 / Batch (10/17) / Loss 41.063908\n",
      "Epoch 7 / Batch (11/17) / Loss 38.299091\n",
      "Epoch 7 / Batch (12/17) / Loss 38.551674\n",
      "Epoch 7 / Batch (13/17) / Loss 38.184181\n",
      "Epoch 7 / Batch (14/17) / Loss 38.125313\n",
      "Epoch 7 / Batch (15/17) / Loss 37.934402\n",
      "Epoch 7 / Batch (16/17) / Loss 37.273670\n",
      "Epoch 7 / Batch (17/17) / Loss 36.511684\n",
      "Epoch 7 complete, 11.228466033935547 seconds elapsed\n",
      "Epoch 8 starting..\n",
      "Epoch 8 / Batch (1/17) / Loss 35.210373\n",
      "Epoch 8 / Batch (2/17) / Loss 34.234993\n",
      "Epoch 8 / Batch (3/17) / Loss 33.107655\n",
      "Epoch 8 / Batch (4/17) / Loss 34.711689\n",
      "Epoch 8 / Batch (5/17) / Loss 33.752159\n",
      "Epoch 8 / Batch (6/17) / Loss 34.336811\n",
      "Epoch 8 / Batch (7/17) / Loss 34.778172\n",
      "Epoch 8 / Batch (8/17) / Loss 34.690495\n",
      "Epoch 8 / Batch (9/17) / Loss 33.412960\n",
      "Epoch 8 / Batch (10/17) / Loss 32.130779\n",
      "Epoch 8 / Batch (11/17) / Loss 33.031136\n",
      "Epoch 8 / Batch (12/17) / Loss 31.563486\n",
      "Epoch 8 / Batch (13/17) / Loss 30.324961\n",
      "Epoch 8 / Batch (14/17) / Loss 31.177719\n",
      "Epoch 8 / Batch (15/17) / Loss 30.299995\n",
      "Epoch 8 / Batch (16/17) / Loss 29.691730\n",
      "Epoch 8 / Batch (17/17) / Loss 30.918211\n",
      "Epoch 8 complete, 11.561179876327515 seconds elapsed\n",
      "Epoch 9 starting..\n",
      "Epoch 9 / Batch (1/17) / Loss 28.877790\n",
      "Epoch 9 / Batch (2/17) / Loss 29.469305\n",
      "Epoch 9 / Batch (3/17) / Loss 27.996712\n",
      "Epoch 9 / Batch (4/17) / Loss 27.972685\n",
      "Epoch 9 / Batch (5/17) / Loss 28.596428\n",
      "Epoch 9 / Batch (6/17) / Loss 26.525145\n",
      "Epoch 9 / Batch (7/17) / Loss 25.324654\n",
      "Epoch 9 / Batch (8/17) / Loss 26.393110\n",
      "Epoch 9 / Batch (9/17) / Loss 25.167038\n",
      "Epoch 9 / Batch (10/17) / Loss 24.984625\n",
      "Epoch 9 / Batch (11/17) / Loss 24.376482\n",
      "Epoch 9 / Batch (12/17) / Loss 25.467052\n",
      "Epoch 9 / Batch (13/17) / Loss 25.258299\n",
      "Epoch 9 / Batch (14/17) / Loss 23.447735\n",
      "Epoch 9 / Batch (15/17) / Loss 24.509975\n",
      "Epoch 9 / Batch (16/17) / Loss 22.534531\n",
      "Epoch 9 / Batch (17/17) / Loss 23.364946\n",
      "Epoch 9 complete, 11.4458749294281 seconds elapsed\n",
      "Epoch 10 starting..\n",
      "Epoch 10 / Batch (1/17) / Loss 22.668886\n",
      "Epoch 10 / Batch (2/17) / Loss 22.801662\n",
      "Epoch 10 / Batch (3/17) / Loss 21.517099\n",
      "Epoch 10 / Batch (4/17) / Loss 20.935415\n",
      "Epoch 10 / Batch (5/17) / Loss 21.258501\n",
      "Epoch 10 / Batch (6/17) / Loss 20.004660\n",
      "Epoch 10 / Batch (7/17) / Loss 20.607216\n",
      "Epoch 10 / Batch (8/17) / Loss 20.567944\n",
      "Epoch 10 / Batch (9/17) / Loss 19.425203\n",
      "Epoch 10 / Batch (10/17) / Loss 19.584171\n",
      "Epoch 10 / Batch (11/17) / Loss 18.805466\n",
      "Epoch 10 / Batch (12/17) / Loss 18.458023\n",
      "Epoch 10 / Batch (13/17) / Loss 18.355219\n",
      "Epoch 10 / Batch (14/17) / Loss 18.203541\n",
      "Epoch 10 / Batch (15/17) / Loss 17.294716\n",
      "Epoch 10 / Batch (16/17) / Loss 18.163824\n",
      "Epoch 10 / Batch (17/17) / Loss 17.225592\n",
      "Epoch 10 complete, 11.29323697090149 seconds elapsed\n",
      "Epoch 11 starting..\n",
      "Epoch 11 / Batch (1/17) / Loss 15.727405\n",
      "Epoch 11 / Batch (2/17) / Loss 16.019157\n",
      "Epoch 11 / Batch (3/17) / Loss 14.791616\n",
      "Epoch 11 / Batch (4/17) / Loss 14.917812\n",
      "Epoch 11 / Batch (5/17) / Loss 14.965934\n",
      "Epoch 11 / Batch (6/17) / Loss 14.982980\n",
      "Epoch 11 / Batch (7/17) / Loss 14.217450\n",
      "Epoch 11 / Batch (8/17) / Loss 14.534347\n",
      "Epoch 11 / Batch (9/17) / Loss 14.158823\n",
      "Epoch 11 / Batch (10/17) / Loss 12.891499\n",
      "Epoch 11 / Batch (11/17) / Loss 12.980457\n",
      "Epoch 11 / Batch (12/17) / Loss 12.729042\n",
      "Epoch 11 / Batch (13/17) / Loss 12.352082\n",
      "Epoch 11 / Batch (14/17) / Loss 11.748609\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11 / Batch (15/17) / Loss 10.893091\n",
      "Epoch 11 / Batch (16/17) / Loss 11.456593\n",
      "Epoch 11 / Batch (17/17) / Loss 10.373083\n",
      "Epoch 11 complete, 11.182086944580078 seconds elapsed\n",
      "Epoch 12 starting..\n",
      "Epoch 12 / Batch (1/17) / Loss 9.829419\n",
      "Epoch 12 / Batch (2/17) / Loss 9.914929\n",
      "Epoch 12 / Batch (3/17) / Loss 9.291359\n",
      "Epoch 12 / Batch (4/17) / Loss 9.657695\n",
      "Epoch 12 / Batch (5/17) / Loss 9.576180\n",
      "Epoch 12 / Batch (6/17) / Loss 8.906301\n",
      "Epoch 12 / Batch (7/17) / Loss 8.582533\n",
      "Epoch 12 / Batch (8/17) / Loss 8.017687\n",
      "Epoch 12 / Batch (9/17) / Loss 7.805200\n",
      "Epoch 12 / Batch (10/17) / Loss 7.868169\n",
      "Epoch 12 / Batch (11/17) / Loss 7.208621\n",
      "Epoch 12 / Batch (12/17) / Loss 7.170284\n",
      "Epoch 12 / Batch (13/17) / Loss 6.972737\n",
      "Epoch 12 / Batch (14/17) / Loss 6.504056\n",
      "Epoch 12 / Batch (15/17) / Loss 6.196991\n",
      "Epoch 12 / Batch (16/17) / Loss 6.307940\n",
      "Epoch 12 / Batch (17/17) / Loss 5.606591\n",
      "Epoch 12 complete, 11.582211971282959 seconds elapsed\n",
      "Epoch 13 starting..\n",
      "Epoch 13 / Batch (1/17) / Loss 5.536714\n",
      "Epoch 13 / Batch (2/17) / Loss 5.177177\n",
      "Epoch 13 / Batch (3/17) / Loss 4.643620\n",
      "Epoch 13 / Batch (4/17) / Loss 4.763430\n",
      "Epoch 13 / Batch (5/17) / Loss 3.984316\n",
      "Epoch 13 / Batch (6/17) / Loss 4.356893\n",
      "Epoch 13 / Batch (7/17) / Loss 3.844254\n",
      "Epoch 13 / Batch (8/17) / Loss 3.375185\n",
      "Epoch 13 / Batch (9/17) / Loss 3.480335\n",
      "Epoch 13 / Batch (10/17) / Loss 3.483876\n",
      "Epoch 13 / Batch (11/17) / Loss 3.445220\n",
      "Epoch 13 / Batch (12/17) / Loss 3.213898\n",
      "Epoch 13 / Batch (13/17) / Loss 3.442640\n",
      "Epoch 13 / Batch (14/17) / Loss 2.923095\n",
      "Epoch 13 / Batch (15/17) / Loss 2.677689\n",
      "Epoch 13 / Batch (16/17) / Loss 2.816588\n",
      "Epoch 13 / Batch (17/17) / Loss 2.985449\n",
      "Epoch 13 complete, 11.59763789176941 seconds elapsed\n",
      "Epoch 14 starting..\n",
      "Epoch 14 / Batch (1/17) / Loss 2.932592\n",
      "Epoch 14 / Batch (2/17) / Loss 2.490438\n",
      "Epoch 14 / Batch (3/17) / Loss 2.931769\n",
      "Epoch 14 / Batch (4/17) / Loss 2.687914\n",
      "Epoch 14 / Batch (5/17) / Loss 2.843840\n",
      "Epoch 14 / Batch (6/17) / Loss 2.444900\n",
      "Epoch 14 / Batch (7/17) / Loss 2.550777\n",
      "Epoch 14 / Batch (8/17) / Loss 2.489547\n",
      "Epoch 14 / Batch (9/17) / Loss 2.403363\n",
      "Epoch 14 / Batch (10/17) / Loss 2.253811\n",
      "Epoch 14 / Batch (11/17) / Loss 2.119659\n",
      "Epoch 14 / Batch (12/17) / Loss 2.157484\n",
      "Epoch 14 / Batch (13/17) / Loss 2.185498\n",
      "Epoch 14 / Batch (14/17) / Loss 2.200127\n",
      "Epoch 14 / Batch (15/17) / Loss 2.277685\n",
      "Epoch 14 / Batch (16/17) / Loss 2.033789\n",
      "Epoch 14 / Batch (17/17) / Loss 2.097018\n",
      "Epoch 14 complete, 11.38538908958435 seconds elapsed\n",
      "Epoch 15 starting..\n",
      "Epoch 15 / Batch (1/17) / Loss 2.127251\n",
      "Epoch 15 / Batch (2/17) / Loss 2.020219\n",
      "Epoch 15 / Batch (3/17) / Loss 2.136348\n",
      "Epoch 15 / Batch (4/17) / Loss 2.054206\n",
      "Epoch 15 / Batch (5/17) / Loss 2.098702\n",
      "Epoch 15 / Batch (6/17) / Loss 1.992432\n",
      "Epoch 15 / Batch (7/17) / Loss 1.993317\n",
      "Epoch 15 / Batch (8/17) / Loss 2.008871\n",
      "Epoch 15 / Batch (9/17) / Loss 1.994698\n",
      "Epoch 15 / Batch (10/17) / Loss 2.020777\n",
      "Epoch 15 / Batch (11/17) / Loss 1.911617\n",
      "Epoch 15 / Batch (12/17) / Loss 1.958822\n",
      "Epoch 15 / Batch (13/17) / Loss 1.965158\n",
      "Epoch 15 / Batch (14/17) / Loss 1.964103\n",
      "Epoch 15 / Batch (15/17) / Loss 2.029332\n",
      "Epoch 15 / Batch (16/17) / Loss 2.034461\n",
      "Epoch 15 / Batch (17/17) / Loss 2.049137\n",
      "Epoch 15 complete, 11.591316938400269 seconds elapsed\n",
      "Epoch 16 starting..\n",
      "Epoch 16 / Batch (1/17) / Loss 2.093146\n",
      "Epoch 16 / Batch (2/17) / Loss 2.014487\n",
      "Epoch 16 / Batch (3/17) / Loss 1.972086\n",
      "Epoch 16 / Batch (4/17) / Loss 2.023763\n",
      "Epoch 16 / Batch (5/17) / Loss 1.983642\n",
      "Epoch 16 / Batch (6/17) / Loss 2.011299\n",
      "Epoch 16 / Batch (7/17) / Loss 1.946330\n",
      "Epoch 16 / Batch (8/17) / Loss 1.973379\n",
      "Epoch 16 / Batch (9/17) / Loss 1.949450\n",
      "Epoch 16 / Batch (10/17) / Loss 1.973855\n",
      "Epoch 16 / Batch (11/17) / Loss 1.957628\n",
      "Epoch 16 / Batch (12/17) / Loss 1.939696\n",
      "Epoch 16 / Batch (13/17) / Loss 1.909089\n",
      "Epoch 16 / Batch (14/17) / Loss 1.937308\n",
      "Epoch 16 / Batch (15/17) / Loss 1.932664\n",
      "Epoch 16 / Batch (16/17) / Loss 1.937013\n",
      "Epoch 16 / Batch (17/17) / Loss 1.910212\n",
      "Epoch 16 complete, 11.706124067306519 seconds elapsed\n",
      "Epoch 17 starting..\n",
      "Epoch 17 / Batch (1/17) / Loss 1.935359\n",
      "Epoch 17 / Batch (2/17) / Loss 1.875443\n",
      "Epoch 17 / Batch (3/17) / Loss 1.873582\n",
      "Epoch 17 / Batch (4/17) / Loss 1.869281\n",
      "Epoch 17 / Batch (5/17) / Loss 1.864751\n",
      "Epoch 17 / Batch (6/17) / Loss 1.866821\n",
      "Epoch 17 / Batch (7/17) / Loss 1.907219\n",
      "Epoch 17 / Batch (8/17) / Loss 1.904050\n",
      "Epoch 17 / Batch (9/17) / Loss 1.869977\n",
      "Epoch 17 / Batch (10/17) / Loss 1.858422\n",
      "Epoch 17 / Batch (11/17) / Loss 1.865846\n",
      "Epoch 17 / Batch (12/17) / Loss 1.862402\n",
      "Epoch 17 / Batch (13/17) / Loss 1.860305\n",
      "Epoch 17 / Batch (14/17) / Loss 1.856950\n",
      "Epoch 17 / Batch (15/17) / Loss 1.882471\n",
      "Epoch 17 / Batch (16/17) / Loss 1.845809\n",
      "Epoch 17 / Batch (17/17) / Loss 1.854824\n",
      "Epoch 17 complete, 11.791397333145142 seconds elapsed\n",
      "Epoch 18 starting..\n",
      "Epoch 18 / Batch (1/17) / Loss 1.845654\n",
      "Epoch 18 / Batch (2/17) / Loss 1.851976\n",
      "Epoch 18 / Batch (3/17) / Loss 1.820044\n",
      "Epoch 18 / Batch (4/17) / Loss 1.831007\n",
      "Epoch 18 / Batch (5/17) / Loss 1.830999\n",
      "Epoch 18 / Batch (6/17) / Loss 1.805761\n",
      "Epoch 18 / Batch (7/17) / Loss 1.844635\n",
      "Epoch 18 / Batch (8/17) / Loss 1.817103\n",
      "Epoch 18 / Batch (9/17) / Loss 1.809008\n",
      "Epoch 18 / Batch (10/17) / Loss 1.806180\n",
      "Epoch 18 / Batch (11/17) / Loss 1.836184\n",
      "Epoch 18 / Batch (12/17) / Loss 1.800155\n",
      "Epoch 18 / Batch (13/17) / Loss 1.826612\n",
      "Epoch 18 / Batch (14/17) / Loss 1.797910\n",
      "Epoch 18 / Batch (15/17) / Loss 1.803306\n",
      "Epoch 18 / Batch (16/17) / Loss 1.785436\n",
      "Epoch 18 / Batch (17/17) / Loss 1.794242\n",
      "Epoch 18 complete, 12.881692171096802 seconds elapsed\n",
      "Epoch 19 starting..\n",
      "Epoch 19 / Batch (1/17) / Loss 1.823817\n",
      "Epoch 19 / Batch (2/17) / Loss 1.779259\n",
      "Epoch 19 / Batch (3/17) / Loss 1.815461\n",
      "Epoch 19 / Batch (4/17) / Loss 1.809270\n",
      "Epoch 19 / Batch (5/17) / Loss 1.805270\n",
      "Epoch 19 / Batch (6/17) / Loss 1.815256\n",
      "Epoch 19 / Batch (7/17) / Loss 1.803304\n",
      "Epoch 19 / Batch (8/17) / Loss 1.818812\n",
      "Epoch 19 / Batch (9/17) / Loss 1.809170\n",
      "Epoch 19 / Batch (10/17) / Loss 1.799449\n",
      "Epoch 19 / Batch (11/17) / Loss 1.773603\n",
      "Epoch 19 / Batch (12/17) / Loss 1.788986\n",
      "Epoch 19 / Batch (13/17) / Loss 1.775160\n",
      "Epoch 19 / Batch (14/17) / Loss 1.809746\n",
      "Epoch 19 / Batch (15/17) / Loss 1.789145\n",
      "Epoch 19 / Batch (16/17) / Loss 1.780032\n",
      "Epoch 19 / Batch (17/17) / Loss 1.783320\n",
      "Epoch 19 complete, 13.362638711929321 seconds elapsed\n",
      "Epoch 20 starting..\n",
      "Epoch 20 / Batch (1/17) / Loss 1.793627\n",
      "Epoch 20 / Batch (2/17) / Loss 1.768111\n",
      "Epoch 20 / Batch (3/17) / Loss 1.814574\n",
      "Epoch 20 / Batch (4/17) / Loss 1.753380\n",
      "Epoch 20 / Batch (5/17) / Loss 1.793041\n",
      "Epoch 20 / Batch (6/17) / Loss 1.813498\n",
      "Epoch 20 / Batch (7/17) / Loss 1.798944\n",
      "Epoch 20 / Batch (8/17) / Loss 1.793713\n",
      "Epoch 20 / Batch (9/17) / Loss 1.797327\n",
      "Epoch 20 / Batch (10/17) / Loss 1.786093\n",
      "Epoch 20 / Batch (11/17) / Loss 1.787656\n",
      "Epoch 20 / Batch (12/17) / Loss 1.771886\n",
      "Epoch 20 / Batch (13/17) / Loss 1.757880\n",
      "Epoch 20 / Batch (14/17) / Loss 1.791387\n",
      "Epoch 20 / Batch (15/17) / Loss 1.780975\n",
      "Epoch 20 / Batch (16/17) / Loss 1.783176\n",
      "Epoch 20 / Batch (17/17) / Loss 1.799418\n",
      "Epoch 20 complete, 14.099467039108276 seconds elapsed\n",
      "Epoch 21 starting..\n",
      "Epoch 21 / Batch (1/17) / Loss 1.792656\n",
      "Epoch 21 / Batch (2/17) / Loss 1.795489\n",
      "Epoch 21 / Batch (3/17) / Loss 1.770981\n",
      "Epoch 21 / Batch (4/17) / Loss 1.788019\n",
      "Epoch 21 / Batch (5/17) / Loss 1.782752\n",
      "Epoch 21 / Batch (6/17) / Loss 1.763991\n",
      "Epoch 21 / Batch (7/17) / Loss 1.771862\n",
      "Epoch 21 / Batch (8/17) / Loss 1.763482\n",
      "Epoch 21 / Batch (9/17) / Loss 1.800532\n",
      "Epoch 21 / Batch (10/17) / Loss 1.820833\n",
      "Epoch 21 / Batch (11/17) / Loss 1.795488\n",
      "Epoch 21 / Batch (12/17) / Loss 1.799342\n",
      "Epoch 21 / Batch (13/17) / Loss 1.761933\n",
      "Epoch 21 / Batch (14/17) / Loss 1.790915\n",
      "Epoch 21 / Batch (15/17) / Loss 1.759300\n",
      "Epoch 21 / Batch (16/17) / Loss 1.749975\n",
      "Epoch 21 / Batch (17/17) / Loss 1.801743\n",
      "Epoch 21 complete, 14.442699909210205 seconds elapsed\n",
      "Epoch 22 starting..\n",
      "Epoch 22 / Batch (1/17) / Loss 1.771073\n",
      "Epoch 22 / Batch (2/17) / Loss 1.767506\n",
      "Epoch 22 / Batch (3/17) / Loss 1.771035\n",
      "Epoch 22 / Batch (4/17) / Loss 1.752570\n",
      "Epoch 22 / Batch (5/17) / Loss 1.786237\n",
      "Epoch 22 / Batch (6/17) / Loss 1.786143\n",
      "Epoch 22 / Batch (7/17) / Loss 1.771045\n",
      "Epoch 22 / Batch (8/17) / Loss 1.779997\n",
      "Epoch 22 / Batch (9/17) / Loss 1.788972\n",
      "Epoch 22 / Batch (10/17) / Loss 1.759544\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 22 / Batch (11/17) / Loss 1.764510\n",
      "Epoch 22 / Batch (12/17) / Loss 1.744342\n",
      "Epoch 22 / Batch (13/17) / Loss 1.778694\n",
      "Epoch 22 / Batch (14/17) / Loss 1.774147\n",
      "Epoch 22 / Batch (15/17) / Loss 1.783040\n",
      "Epoch 22 / Batch (16/17) / Loss 1.774013\n",
      "Epoch 22 / Batch (17/17) / Loss 1.755073\n",
      "Epoch 22 complete, 13.780434846878052 seconds elapsed\n",
      "Epoch 23 starting..\n",
      "Epoch 23 / Batch (1/17) / Loss 1.761404\n",
      "Epoch 23 / Batch (2/17) / Loss 1.793237\n",
      "Epoch 23 / Batch (3/17) / Loss 1.779007\n",
      "Epoch 23 / Batch (4/17) / Loss 1.779112\n",
      "Epoch 23 / Batch (5/17) / Loss 1.770565\n",
      "Epoch 23 / Batch (6/17) / Loss 1.769209\n",
      "Epoch 23 / Batch (7/17) / Loss 1.753875\n",
      "Epoch 23 / Batch (8/17) / Loss 1.749051\n",
      "Epoch 23 / Batch (9/17) / Loss 1.761039\n",
      "Epoch 23 / Batch (10/17) / Loss 1.743680\n",
      "Epoch 23 / Batch (11/17) / Loss 1.749750\n",
      "Epoch 23 / Batch (12/17) / Loss 1.743833\n",
      "Epoch 23 / Batch (13/17) / Loss 1.741233\n",
      "Epoch 23 / Batch (14/17) / Loss 1.754305\n",
      "Epoch 23 / Batch (15/17) / Loss 1.742612\n",
      "Epoch 23 / Batch (16/17) / Loss 1.744159\n",
      "Epoch 23 / Batch (17/17) / Loss 1.756656\n",
      "Epoch 23 complete, 13.009907245635986 seconds elapsed\n",
      "Epoch 24 starting..\n",
      "Epoch 24 / Batch (1/17) / Loss 1.749492\n",
      "Epoch 24 / Batch (2/17) / Loss 1.740361\n",
      "Epoch 24 / Batch (3/17) / Loss 1.764565\n",
      "Epoch 24 / Batch (4/17) / Loss 1.733433\n",
      "Epoch 24 / Batch (5/17) / Loss 1.740840\n",
      "Epoch 24 / Batch (6/17) / Loss 1.734647\n",
      "Epoch 24 / Batch (7/17) / Loss 1.720979\n",
      "Epoch 24 / Batch (8/17) / Loss 1.761929\n",
      "Epoch 24 / Batch (9/17) / Loss 1.735859\n",
      "Epoch 24 / Batch (10/17) / Loss 1.748967\n",
      "Epoch 24 / Batch (11/17) / Loss 1.730057\n",
      "Epoch 24 / Batch (12/17) / Loss 1.751913\n",
      "Epoch 24 / Batch (13/17) / Loss 1.730474\n",
      "Epoch 24 / Batch (14/17) / Loss 1.751113\n",
      "Epoch 24 / Batch (15/17) / Loss 1.751253\n",
      "Epoch 24 / Batch (16/17) / Loss 1.758998\n",
      "Epoch 24 / Batch (17/17) / Loss 1.722917\n",
      "Epoch 24 complete, 13.169474840164185 seconds elapsed\n",
      "Epoch 25 starting..\n",
      "Epoch 25 / Batch (1/17) / Loss 1.751658\n",
      "Epoch 25 / Batch (2/17) / Loss 1.748187\n",
      "Epoch 25 / Batch (3/17) / Loss 1.732587\n",
      "Epoch 25 / Batch (4/17) / Loss 1.723606\n",
      "Epoch 25 / Batch (5/17) / Loss 1.724418\n",
      "Epoch 25 / Batch (6/17) / Loss 1.730909\n",
      "Epoch 25 / Batch (7/17) / Loss 1.712951\n",
      "Epoch 25 / Batch (8/17) / Loss 1.743131\n",
      "Epoch 25 / Batch (9/17) / Loss 1.709291\n",
      "Epoch 25 / Batch (10/17) / Loss 1.733622\n",
      "Epoch 25 / Batch (11/17) / Loss 1.736617\n",
      "Epoch 25 / Batch (12/17) / Loss 1.751430\n",
      "Epoch 25 / Batch (13/17) / Loss 1.709576\n",
      "Epoch 25 / Batch (14/17) / Loss 1.737791\n",
      "Epoch 25 / Batch (15/17) / Loss 1.734658\n",
      "Epoch 25 / Batch (16/17) / Loss 1.731083\n",
      "Epoch 25 / Batch (17/17) / Loss 1.742078\n",
      "Epoch 25 complete, 13.389103174209595 seconds elapsed\n",
      "Epoch 26 starting..\n",
      "Epoch 26 / Batch (1/17) / Loss 1.739387\n",
      "Epoch 26 / Batch (2/17) / Loss 1.727146\n",
      "Epoch 26 / Batch (3/17) / Loss 1.750782\n",
      "Epoch 26 / Batch (4/17) / Loss 1.707998\n",
      "Epoch 26 / Batch (5/17) / Loss 1.721432\n",
      "Epoch 26 / Batch (6/17) / Loss 1.735566\n",
      "Epoch 26 / Batch (7/17) / Loss 1.737834\n",
      "Epoch 26 / Batch (8/17) / Loss 1.738270\n",
      "Epoch 26 / Batch (9/17) / Loss 1.734012\n",
      "Epoch 26 / Batch (10/17) / Loss 1.740612\n",
      "Epoch 26 / Batch (11/17) / Loss 1.733912\n",
      "Epoch 26 / Batch (12/17) / Loss 1.707920\n",
      "Epoch 26 / Batch (13/17) / Loss 1.721379\n",
      "Epoch 26 / Batch (14/17) / Loss 1.742118\n",
      "Epoch 26 / Batch (15/17) / Loss 1.725551\n",
      "Epoch 26 / Batch (16/17) / Loss 1.742738\n",
      "Epoch 26 / Batch (17/17) / Loss 1.739947\n",
      "Epoch 26 complete, 13.39096188545227 seconds elapsed\n",
      "Epoch 27 starting..\n",
      "Epoch 27 / Batch (1/17) / Loss 1.738588\n",
      "Epoch 27 / Batch (2/17) / Loss 1.729296\n",
      "Epoch 27 / Batch (3/17) / Loss 1.721552\n",
      "Epoch 27 / Batch (4/17) / Loss 1.722377\n",
      "Epoch 27 / Batch (5/17) / Loss 1.708588\n",
      "Epoch 27 / Batch (6/17) / Loss 1.713838\n",
      "Epoch 27 / Batch (7/17) / Loss 1.718535\n",
      "Epoch 27 / Batch (8/17) / Loss 1.730721\n",
      "Epoch 27 / Batch (9/17) / Loss 1.725866\n",
      "Epoch 27 / Batch (10/17) / Loss 1.713316\n",
      "Epoch 27 / Batch (11/17) / Loss 1.719987\n",
      "Epoch 27 / Batch (12/17) / Loss 1.722442\n",
      "Epoch 27 / Batch (13/17) / Loss 1.669687\n",
      "Epoch 27 / Batch (14/17) / Loss 1.716516\n",
      "Epoch 27 / Batch (15/17) / Loss 1.724891\n",
      "Epoch 27 / Batch (16/17) / Loss 1.700842\n",
      "Epoch 27 / Batch (17/17) / Loss 1.750221\n",
      "Epoch 27 complete, 13.331692934036255 seconds elapsed\n",
      "Epoch 28 starting..\n",
      "Epoch 28 / Batch (1/17) / Loss 1.708675\n",
      "Epoch 28 / Batch (2/17) / Loss 1.719538\n",
      "Epoch 28 / Batch (3/17) / Loss 1.718097\n",
      "Epoch 28 / Batch (4/17) / Loss 1.738132\n",
      "Epoch 28 / Batch (5/17) / Loss 1.728583\n",
      "Epoch 28 / Batch (6/17) / Loss 1.714974\n",
      "Epoch 28 / Batch (7/17) / Loss 1.715622\n",
      "Epoch 28 / Batch (8/17) / Loss 1.715958\n",
      "Epoch 28 / Batch (9/17) / Loss 1.713956\n",
      "Epoch 28 / Batch (10/17) / Loss 1.728291\n",
      "Epoch 28 / Batch (11/17) / Loss 1.715280\n",
      "Epoch 28 / Batch (12/17) / Loss 1.730662\n",
      "Epoch 28 / Batch (13/17) / Loss 1.742337\n",
      "Epoch 28 / Batch (14/17) / Loss 1.719487\n",
      "Epoch 28 / Batch (15/17) / Loss 1.719563\n",
      "Epoch 28 / Batch (16/17) / Loss 1.703225\n",
      "Epoch 28 / Batch (17/17) / Loss 1.702621\n",
      "Epoch 28 complete, 14.469742059707642 seconds elapsed\n",
      "Epoch 29 starting..\n",
      "Epoch 29 / Batch (1/17) / Loss 1.737707\n",
      "Epoch 29 / Batch (2/17) / Loss 1.719047\n",
      "Epoch 29 / Batch (3/17) / Loss 1.735378\n",
      "Epoch 29 / Batch (4/17) / Loss 1.712429\n",
      "Epoch 29 / Batch (5/17) / Loss 1.733383\n",
      "Epoch 29 / Batch (6/17) / Loss 1.724583\n",
      "Epoch 29 / Batch (7/17) / Loss 1.732526\n",
      "Epoch 29 / Batch (8/17) / Loss 1.730975\n",
      "Epoch 29 / Batch (9/17) / Loss 1.695748\n",
      "Epoch 29 / Batch (10/17) / Loss 1.723432\n",
      "Epoch 29 / Batch (11/17) / Loss 1.725592\n",
      "Epoch 29 / Batch (12/17) / Loss 1.718608\n",
      "Epoch 29 / Batch (13/17) / Loss 1.732820\n",
      "Epoch 29 / Batch (14/17) / Loss 1.717926\n",
      "Epoch 29 / Batch (15/17) / Loss 1.701529\n",
      "Epoch 29 / Batch (16/17) / Loss 1.747056\n",
      "Epoch 29 / Batch (17/17) / Loss 1.709515\n",
      "Epoch 29 complete, 16.55297017097473 seconds elapsed\n",
      "Epoch 30 starting..\n",
      "Epoch 30 / Batch (1/17) / Loss 1.721995\n",
      "Epoch 30 / Batch (2/17) / Loss 1.721470\n",
      "Epoch 30 / Batch (3/17) / Loss 1.717329\n",
      "Epoch 30 / Batch (4/17) / Loss 1.729489\n",
      "Epoch 30 / Batch (5/17) / Loss 1.727882\n",
      "Epoch 30 / Batch (6/17) / Loss 1.713757\n",
      "Epoch 30 / Batch (7/17) / Loss 1.717776\n",
      "Epoch 30 / Batch (8/17) / Loss 1.726444\n",
      "Epoch 30 / Batch (9/17) / Loss 1.711502\n",
      "Epoch 30 / Batch (10/17) / Loss 1.718787\n",
      "Epoch 30 / Batch (11/17) / Loss 1.722828\n",
      "Epoch 30 / Batch (12/17) / Loss 1.730620\n",
      "Epoch 30 / Batch (13/17) / Loss 1.708663\n",
      "Epoch 30 / Batch (14/17) / Loss 1.720300\n",
      "Epoch 30 / Batch (15/17) / Loss 1.723922\n",
      "Epoch 30 / Batch (16/17) / Loss 1.730332\n",
      "Epoch 30 / Batch (17/17) / Loss 1.719331\n",
      "Epoch 30 complete, 16.314176082611084 seconds elapsed\n",
      "Epoch 31 starting..\n",
      "Epoch 31 / Batch (1/17) / Loss 1.713760\n",
      "Epoch 31 / Batch (2/17) / Loss 1.710237\n",
      "Epoch 31 / Batch (3/17) / Loss 1.718039\n",
      "Epoch 31 / Batch (4/17) / Loss 1.707437\n",
      "Epoch 31 / Batch (5/17) / Loss 1.707717\n",
      "Epoch 31 / Batch (6/17) / Loss 1.700356\n",
      "Epoch 31 / Batch (7/17) / Loss 1.709437\n",
      "Epoch 31 / Batch (8/17) / Loss 1.708212\n",
      "Epoch 31 / Batch (9/17) / Loss 1.695550\n",
      "Epoch 31 / Batch (10/17) / Loss 1.713977\n",
      "Epoch 31 / Batch (11/17) / Loss 1.693291\n",
      "Epoch 31 / Batch (12/17) / Loss 1.712032\n",
      "Epoch 31 / Batch (13/17) / Loss 1.720200\n",
      "Epoch 31 / Batch (14/17) / Loss 1.718070\n",
      "Epoch 31 / Batch (15/17) / Loss 1.712011\n",
      "Epoch 31 / Batch (16/17) / Loss 1.715991\n",
      "Epoch 31 / Batch (17/17) / Loss 1.721354\n",
      "Epoch 31 complete, 16.566627025604248 seconds elapsed\n",
      "Epoch 32 starting..\n",
      "Epoch 32 / Batch (1/17) / Loss 1.700432\n",
      "Epoch 32 / Batch (2/17) / Loss 1.722108\n",
      "Epoch 32 / Batch (3/17) / Loss 1.720418\n",
      "Epoch 32 / Batch (4/17) / Loss 1.699609\n",
      "Epoch 32 / Batch (5/17) / Loss 1.708991\n",
      "Epoch 32 / Batch (6/17) / Loss 1.697243\n",
      "Epoch 32 / Batch (7/17) / Loss 1.706711\n",
      "Epoch 32 / Batch (8/17) / Loss 1.714127\n",
      "Epoch 32 / Batch (9/17) / Loss 1.717331\n",
      "Epoch 32 / Batch (10/17) / Loss 1.709560\n",
      "Epoch 32 / Batch (11/17) / Loss 1.715687\n",
      "Epoch 32 / Batch (12/17) / Loss 1.728855\n",
      "Epoch 32 / Batch (13/17) / Loss 1.704213\n",
      "Epoch 32 / Batch (14/17) / Loss 1.728407\n",
      "Epoch 32 / Batch (15/17) / Loss 1.723170\n",
      "Epoch 32 / Batch (16/17) / Loss 1.717136\n",
      "Epoch 32 / Batch (17/17) / Loss 1.720979\n",
      "Epoch 32 complete, 14.08050799369812 seconds elapsed\n",
      "Epoch 33 starting..\n",
      "Epoch 33 / Batch (1/17) / Loss 1.712584\n",
      "Epoch 33 / Batch (2/17) / Loss 1.709962\n",
      "Epoch 33 / Batch (3/17) / Loss 1.720954\n",
      "Epoch 33 / Batch (4/17) / Loss 1.706393\n",
      "Epoch 33 / Batch (5/17) / Loss 1.708360\n",
      "Epoch 33 / Batch (6/17) / Loss 1.716872\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 33 / Batch (7/17) / Loss 1.713011\n",
      "Epoch 33 / Batch (8/17) / Loss 1.716750\n",
      "Epoch 33 / Batch (9/17) / Loss 1.718077\n",
      "Epoch 33 / Batch (10/17) / Loss 1.705534\n",
      "Epoch 33 / Batch (11/17) / Loss 1.731735\n",
      "Epoch 33 / Batch (12/17) / Loss 1.717885\n",
      "Epoch 33 / Batch (13/17) / Loss 1.680782\n",
      "Epoch 33 / Batch (14/17) / Loss 1.721704\n",
      "Epoch 33 / Batch (15/17) / Loss 1.726072\n",
      "Epoch 33 / Batch (16/17) / Loss 1.695835\n",
      "Epoch 33 / Batch (17/17) / Loss 1.712241\n",
      "Epoch 33 complete, 13.905726909637451 seconds elapsed\n",
      "Epoch 34 starting..\n",
      "Epoch 34 / Batch (1/17) / Loss 1.721817\n",
      "Epoch 34 / Batch (2/17) / Loss 1.706303\n",
      "Epoch 34 / Batch (3/17) / Loss 1.723582\n",
      "Epoch 34 / Batch (4/17) / Loss 1.700979\n",
      "Epoch 34 / Batch (5/17) / Loss 1.716779\n",
      "Epoch 34 / Batch (6/17) / Loss 1.695520\n",
      "Epoch 34 / Batch (7/17) / Loss 1.703618\n",
      "Epoch 34 / Batch (8/17) / Loss 1.707929\n",
      "Epoch 34 / Batch (9/17) / Loss 1.692410\n",
      "Epoch 34 / Batch (10/17) / Loss 1.722330\n",
      "Epoch 34 / Batch (11/17) / Loss 1.711591\n",
      "Epoch 34 / Batch (12/17) / Loss 1.720562\n",
      "Epoch 34 / Batch (13/17) / Loss 1.698151\n",
      "Epoch 34 / Batch (14/17) / Loss 1.710979\n",
      "Epoch 34 / Batch (15/17) / Loss 1.711328\n",
      "Epoch 34 / Batch (16/17) / Loss 1.709732\n",
      "Epoch 34 / Batch (17/17) / Loss 1.715285\n",
      "Epoch 34 complete, 12.735618829727173 seconds elapsed\n",
      "Epoch 35 starting..\n",
      "Epoch 35 / Batch (1/17) / Loss 1.718357\n",
      "Epoch 35 / Batch (2/17) / Loss 1.705677\n",
      "Epoch 35 / Batch (3/17) / Loss 1.710707\n",
      "Epoch 35 / Batch (4/17) / Loss 1.709963\n",
      "Epoch 35 / Batch (5/17) / Loss 1.704929\n",
      "Epoch 35 / Batch (6/17) / Loss 1.699983\n",
      "Epoch 35 / Batch (7/17) / Loss 1.691019\n",
      "Epoch 35 / Batch (8/17) / Loss 1.702098\n",
      "Epoch 35 / Batch (9/17) / Loss 1.703856\n",
      "Epoch 35 / Batch (10/17) / Loss 1.711970\n",
      "Epoch 35 / Batch (11/17) / Loss 1.701631\n",
      "Epoch 35 / Batch (12/17) / Loss 1.719796\n",
      "Epoch 35 / Batch (13/17) / Loss 1.689083\n",
      "Epoch 35 / Batch (14/17) / Loss 1.718768\n",
      "Epoch 35 / Batch (15/17) / Loss 1.704260\n",
      "Epoch 35 / Batch (16/17) / Loss 1.713945\n",
      "Epoch 35 / Batch (17/17) / Loss 1.705923\n",
      "Epoch 35 complete, 13.717657089233398 seconds elapsed\n",
      "Epoch 36 starting..\n",
      "Epoch 36 / Batch (1/17) / Loss 1.702510\n",
      "Epoch 36 / Batch (2/17) / Loss 1.702678\n",
      "Epoch 36 / Batch (3/17) / Loss 1.704552\n",
      "Epoch 36 / Batch (4/17) / Loss 1.703875\n",
      "Epoch 36 / Batch (5/17) / Loss 1.677826\n",
      "Epoch 36 / Batch (6/17) / Loss 1.707109\n",
      "Epoch 36 / Batch (7/17) / Loss 1.684029\n",
      "Epoch 36 / Batch (8/17) / Loss 1.700378\n",
      "Epoch 36 / Batch (9/17) / Loss 1.686322\n",
      "Epoch 36 / Batch (10/17) / Loss 1.692343\n",
      "Epoch 36 / Batch (11/17) / Loss 1.701366\n",
      "Epoch 36 / Batch (12/17) / Loss 1.729361\n",
      "Epoch 36 / Batch (13/17) / Loss 1.692279\n",
      "Epoch 36 / Batch (14/17) / Loss 1.700010\n",
      "Epoch 36 / Batch (15/17) / Loss 1.719460\n",
      "Epoch 36 / Batch (16/17) / Loss 1.717773\n",
      "Epoch 36 / Batch (17/17) / Loss 1.713554\n",
      "Epoch 36 complete, 12.128651142120361 seconds elapsed\n",
      "Epoch 37 starting..\n",
      "Epoch 37 / Batch (1/17) / Loss 1.718347\n",
      "Epoch 37 / Batch (2/17) / Loss 1.708208\n",
      "Epoch 37 / Batch (3/17) / Loss 1.704082\n",
      "Epoch 37 / Batch (4/17) / Loss 1.702346\n",
      "Epoch 37 / Batch (5/17) / Loss 1.694506\n",
      "Epoch 37 / Batch (6/17) / Loss 1.703999\n",
      "Epoch 37 / Batch (7/17) / Loss 1.698282\n",
      "Epoch 37 / Batch (8/17) / Loss 1.707852\n",
      "Epoch 37 / Batch (9/17) / Loss 1.676771\n",
      "Epoch 37 / Batch (10/17) / Loss 1.688673\n",
      "Epoch 37 / Batch (11/17) / Loss 1.692632\n",
      "Epoch 37 / Batch (12/17) / Loss 1.674296\n",
      "Epoch 37 / Batch (13/17) / Loss 1.700577\n",
      "Epoch 37 / Batch (14/17) / Loss 1.709276\n",
      "Epoch 37 / Batch (15/17) / Loss 1.697683\n",
      "Epoch 37 / Batch (16/17) / Loss 1.706888\n",
      "Epoch 37 / Batch (17/17) / Loss 1.714611\n",
      "Epoch 37 complete, 12.892301082611084 seconds elapsed\n",
      "Epoch 38 starting..\n",
      "Epoch 38 / Batch (1/17) / Loss 1.692563\n",
      "Epoch 38 / Batch (2/17) / Loss 1.692016\n",
      "Epoch 38 / Batch (3/17) / Loss 1.679494\n",
      "Epoch 38 / Batch (4/17) / Loss 1.664910\n",
      "Epoch 38 / Batch (5/17) / Loss 1.693268\n",
      "Epoch 38 / Batch (6/17) / Loss 1.705332\n",
      "Epoch 38 / Batch (7/17) / Loss 1.679582\n",
      "Epoch 38 / Batch (8/17) / Loss 1.693482\n",
      "Epoch 38 / Batch (9/17) / Loss 1.689187\n",
      "Epoch 38 / Batch (10/17) / Loss 1.705718\n",
      "Epoch 38 / Batch (11/17) / Loss 1.684024\n",
      "Epoch 38 / Batch (12/17) / Loss 1.692228\n",
      "Epoch 38 / Batch (13/17) / Loss 1.698780\n",
      "Epoch 38 / Batch (14/17) / Loss 1.688426\n",
      "Epoch 38 / Batch (15/17) / Loss 1.705869\n",
      "Epoch 38 / Batch (16/17) / Loss 1.685789\n",
      "Epoch 38 / Batch (17/17) / Loss 1.695529\n",
      "Epoch 38 complete, 12.644124031066895 seconds elapsed\n",
      "Epoch 39 starting..\n",
      "Epoch 39 / Batch (1/17) / Loss 1.693458\n",
      "Epoch 39 / Batch (2/17) / Loss 1.708943\n",
      "Epoch 39 / Batch (3/17) / Loss 1.689473\n",
      "Epoch 39 / Batch (4/17) / Loss 1.691641\n",
      "Epoch 39 / Batch (5/17) / Loss 1.683828\n",
      "Epoch 39 / Batch (6/17) / Loss 1.692045\n",
      "Epoch 39 / Batch (7/17) / Loss 1.704771\n",
      "Epoch 39 / Batch (8/17) / Loss 1.694256\n",
      "Epoch 39 / Batch (9/17) / Loss 1.698082\n",
      "Epoch 39 / Batch (10/17) / Loss 1.702910\n",
      "Epoch 39 / Batch (11/17) / Loss 1.694023\n",
      "Epoch 39 / Batch (12/17) / Loss 1.700508\n",
      "Epoch 39 / Batch (13/17) / Loss 1.709911\n",
      "Epoch 39 / Batch (14/17) / Loss 1.696794\n",
      "Epoch 39 / Batch (15/17) / Loss 1.688734\n",
      "Epoch 39 / Batch (16/17) / Loss 1.720561\n",
      "Epoch 39 / Batch (17/17) / Loss 1.700347\n",
      "Epoch 39 complete, 11.816447019577026 seconds elapsed\n",
      "Epoch 40 starting..\n",
      "Epoch 40 / Batch (1/17) / Loss 1.701732\n",
      "Epoch 40 / Batch (2/17) / Loss 1.692680\n",
      "Epoch 40 / Batch (3/17) / Loss 1.707118\n",
      "Epoch 40 / Batch (4/17) / Loss 1.686219\n",
      "Epoch 40 / Batch (5/17) / Loss 1.690695\n",
      "Epoch 40 / Batch (6/17) / Loss 1.674589\n",
      "Epoch 40 / Batch (7/17) / Loss 1.697980\n",
      "Epoch 40 / Batch (8/17) / Loss 1.695121\n",
      "Epoch 40 / Batch (9/17) / Loss 1.686590\n",
      "Epoch 40 / Batch (10/17) / Loss 1.676527\n",
      "Epoch 40 / Batch (11/17) / Loss 1.698967\n",
      "Epoch 40 / Batch (12/17) / Loss 1.701540\n",
      "Epoch 40 / Batch (13/17) / Loss 1.700904\n",
      "Epoch 40 / Batch (14/17) / Loss 1.694901\n",
      "Epoch 40 / Batch (15/17) / Loss 1.709908\n",
      "Epoch 40 / Batch (16/17) / Loss 1.700496\n",
      "Epoch 40 / Batch (17/17) / Loss 1.679121\n",
      "Epoch 40 complete, 11.485016107559204 seconds elapsed\n",
      "Epoch 41 starting..\n",
      "Epoch 41 / Batch (1/17) / Loss 1.705726\n",
      "Epoch 41 / Batch (2/17) / Loss 1.679400\n",
      "Epoch 41 / Batch (3/17) / Loss 1.679857\n",
      "Epoch 41 / Batch (4/17) / Loss 1.694953\n",
      "Epoch 41 / Batch (5/17) / Loss 1.704634\n",
      "Epoch 41 / Batch (6/17) / Loss 1.696073\n",
      "Epoch 41 / Batch (7/17) / Loss 1.691380\n",
      "Epoch 41 / Batch (8/17) / Loss 1.707922\n",
      "Epoch 41 / Batch (9/17) / Loss 1.702839\n",
      "Epoch 41 / Batch (10/17) / Loss 1.694065\n",
      "Epoch 41 / Batch (11/17) / Loss 1.701242\n",
      "Epoch 41 / Batch (12/17) / Loss 1.699505\n",
      "Epoch 41 / Batch (13/17) / Loss 1.683150\n",
      "Epoch 41 / Batch (14/17) / Loss 1.702547\n",
      "Epoch 41 / Batch (15/17) / Loss 1.704773\n",
      "Epoch 41 / Batch (16/17) / Loss 1.702298\n",
      "Epoch 41 / Batch (17/17) / Loss 1.689858\n",
      "Epoch 41 complete, 11.907846927642822 seconds elapsed\n",
      "Epoch 42 starting..\n",
      "Epoch 42 / Batch (1/17) / Loss 1.700607\n",
      "Epoch 42 / Batch (2/17) / Loss 1.693373\n",
      "Epoch 42 / Batch (3/17) / Loss 1.696576\n",
      "Epoch 42 / Batch (4/17) / Loss 1.690523\n",
      "Epoch 42 / Batch (5/17) / Loss 1.699455\n",
      "Epoch 42 / Batch (6/17) / Loss 1.700936\n",
      "Epoch 42 / Batch (7/17) / Loss 1.698315\n",
      "Epoch 42 / Batch (8/17) / Loss 1.695161\n",
      "Epoch 42 / Batch (9/17) / Loss 1.696867\n",
      "Epoch 42 / Batch (10/17) / Loss 1.704349\n",
      "Epoch 42 / Batch (11/17) / Loss 1.702706\n",
      "Epoch 42 / Batch (12/17) / Loss 1.708594\n",
      "Epoch 42 / Batch (13/17) / Loss 1.698348\n",
      "Epoch 42 / Batch (14/17) / Loss 1.693378\n",
      "Epoch 42 / Batch (15/17) / Loss 1.688495\n",
      "Epoch 42 / Batch (16/17) / Loss 1.689026\n",
      "Epoch 42 / Batch (17/17) / Loss 1.701253\n",
      "Epoch 42 complete, 11.832143783569336 seconds elapsed\n",
      "Epoch 43 starting..\n",
      "Epoch 43 / Batch (1/17) / Loss 1.710460\n",
      "Epoch 43 / Batch (2/17) / Loss 1.695433\n",
      "Epoch 43 / Batch (3/17) / Loss 1.702863\n",
      "Epoch 43 / Batch (4/17) / Loss 1.674215\n",
      "Epoch 43 / Batch (5/17) / Loss 1.686179\n",
      "Epoch 43 / Batch (6/17) / Loss 1.694947\n",
      "Epoch 43 / Batch (7/17) / Loss 1.692218\n",
      "Epoch 43 / Batch (8/17) / Loss 1.701642\n",
      "Epoch 43 / Batch (9/17) / Loss 1.703865\n",
      "Epoch 43 / Batch (10/17) / Loss 1.677774\n",
      "Epoch 43 / Batch (11/17) / Loss 1.685026\n",
      "Epoch 43 / Batch (12/17) / Loss 1.670283\n",
      "Epoch 43 / Batch (13/17) / Loss 1.685348\n",
      "Epoch 43 / Batch (14/17) / Loss 1.679006\n",
      "Epoch 43 / Batch (15/17) / Loss 1.665129\n",
      "Epoch 43 / Batch (16/17) / Loss 1.689036\n",
      "Epoch 43 / Batch (17/17) / Loss 1.682261\n",
      "Epoch 43 complete, 12.089302778244019 seconds elapsed\n",
      "Epoch 44 starting..\n",
      "Epoch 44 / Batch (1/17) / Loss 1.680895\n",
      "Epoch 44 / Batch (2/17) / Loss 1.693895\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 44 / Batch (3/17) / Loss 1.681794\n",
      "Epoch 44 / Batch (4/17) / Loss 1.676580\n",
      "Epoch 44 / Batch (5/17) / Loss 1.700832\n",
      "Epoch 44 / Batch (6/17) / Loss 1.701390\n",
      "Epoch 44 / Batch (7/17) / Loss 1.679261\n",
      "Epoch 44 / Batch (8/17) / Loss 1.686363\n",
      "Epoch 44 / Batch (9/17) / Loss 1.700104\n",
      "Epoch 44 / Batch (10/17) / Loss 1.698209\n",
      "Epoch 44 / Batch (11/17) / Loss 1.706776\n",
      "Epoch 44 / Batch (12/17) / Loss 1.706330\n",
      "Epoch 44 / Batch (13/17) / Loss 1.678398\n",
      "Epoch 44 / Batch (14/17) / Loss 1.680387\n",
      "Epoch 44 / Batch (15/17) / Loss 1.688161\n",
      "Epoch 44 / Batch (16/17) / Loss 1.701527\n",
      "Epoch 44 / Batch (17/17) / Loss 1.703416\n",
      "Epoch 44 complete, 11.607014179229736 seconds elapsed\n",
      "Epoch 45 starting..\n",
      "Epoch 45 / Batch (1/17) / Loss 1.682180\n",
      "Epoch 45 / Batch (2/17) / Loss 1.702215\n",
      "Epoch 45 / Batch (3/17) / Loss 1.679529\n",
      "Epoch 45 / Batch (4/17) / Loss 1.693817\n",
      "Epoch 45 / Batch (5/17) / Loss 1.687120\n",
      "Epoch 45 / Batch (6/17) / Loss 1.696430\n",
      "Epoch 45 / Batch (7/17) / Loss 1.710035\n",
      "Epoch 45 / Batch (8/17) / Loss 1.693372\n",
      "Epoch 45 / Batch (9/17) / Loss 1.696351\n",
      "Epoch 45 / Batch (10/17) / Loss 1.683934\n",
      "Epoch 45 / Batch (11/17) / Loss 1.700703\n",
      "Epoch 45 / Batch (12/17) / Loss 1.705705\n",
      "Epoch 45 / Batch (13/17) / Loss 1.695330\n",
      "Epoch 45 / Batch (14/17) / Loss 1.710743\n",
      "Epoch 45 / Batch (15/17) / Loss 1.707003\n",
      "Epoch 45 / Batch (16/17) / Loss 1.704255\n",
      "Epoch 45 / Batch (17/17) / Loss 1.712297\n",
      "Epoch 45 complete, 12.052214622497559 seconds elapsed\n",
      "Epoch 46 starting..\n",
      "Epoch 46 / Batch (1/17) / Loss 1.701990\n",
      "Epoch 46 / Batch (2/17) / Loss 1.691557\n",
      "Epoch 46 / Batch (3/17) / Loss 1.686813\n",
      "Epoch 46 / Batch (4/17) / Loss 1.696628\n",
      "Epoch 46 / Batch (5/17) / Loss 1.716601\n",
      "Epoch 46 / Batch (6/17) / Loss 1.699005\n",
      "Epoch 46 / Batch (7/17) / Loss 1.703104\n",
      "Epoch 46 / Batch (8/17) / Loss 1.701388\n",
      "Epoch 46 / Batch (9/17) / Loss 1.695648\n",
      "Epoch 46 / Batch (10/17) / Loss 1.704569\n",
      "Epoch 46 / Batch (11/17) / Loss 1.670537\n",
      "Epoch 46 / Batch (12/17) / Loss 1.705584\n",
      "Epoch 46 / Batch (13/17) / Loss 1.690086\n",
      "Epoch 46 / Batch (14/17) / Loss 1.674388\n",
      "Epoch 46 / Batch (15/17) / Loss 1.699052\n",
      "Epoch 46 / Batch (16/17) / Loss 1.696090\n",
      "Epoch 46 / Batch (17/17) / Loss 1.701320\n",
      "Epoch 46 complete, 10.918588876724243 seconds elapsed\n",
      "Epoch 47 starting..\n",
      "Epoch 47 / Batch (1/17) / Loss 1.691306\n",
      "Epoch 47 / Batch (2/17) / Loss 1.687041\n",
      "Epoch 47 / Batch (3/17) / Loss 1.685270\n",
      "Epoch 47 / Batch (4/17) / Loss 1.710355\n",
      "Epoch 47 / Batch (5/17) / Loss 1.688445\n",
      "Epoch 47 / Batch (6/17) / Loss 1.684241\n",
      "Epoch 47 / Batch (7/17) / Loss 1.668459\n",
      "Epoch 47 / Batch (8/17) / Loss 1.688261\n",
      "Epoch 47 / Batch (9/17) / Loss 1.696327\n",
      "Epoch 47 / Batch (10/17) / Loss 1.684378\n",
      "Epoch 47 / Batch (11/17) / Loss 1.682069\n",
      "Epoch 47 / Batch (12/17) / Loss 1.683354\n",
      "Epoch 47 / Batch (13/17) / Loss 1.680457\n",
      "Epoch 47 / Batch (14/17) / Loss 1.694378\n",
      "Epoch 47 / Batch (15/17) / Loss 1.685169\n",
      "Epoch 47 / Batch (16/17) / Loss 1.690130\n",
      "Epoch 47 / Batch (17/17) / Loss 1.686319\n",
      "Epoch 47 complete, 11.113113164901733 seconds elapsed\n",
      "Epoch 48 starting..\n",
      "Epoch 48 / Batch (1/17) / Loss 1.672711\n",
      "Epoch 48 / Batch (2/17) / Loss 1.694653\n",
      "Epoch 48 / Batch (3/17) / Loss 1.689026\n",
      "Epoch 48 / Batch (4/17) / Loss 1.698223\n",
      "Epoch 48 / Batch (5/17) / Loss 1.689198\n",
      "Epoch 48 / Batch (6/17) / Loss 1.677096\n",
      "Epoch 48 / Batch (7/17) / Loss 1.698832\n",
      "Epoch 48 / Batch (8/17) / Loss 1.696956\n",
      "Epoch 48 / Batch (9/17) / Loss 1.681977\n",
      "Epoch 48 / Batch (10/17) / Loss 1.718221\n",
      "Epoch 48 / Batch (11/17) / Loss 1.685341\n",
      "Epoch 48 / Batch (12/17) / Loss 1.707804\n",
      "Epoch 48 / Batch (13/17) / Loss 1.688736\n",
      "Epoch 48 / Batch (14/17) / Loss 1.673243\n",
      "Epoch 48 / Batch (15/17) / Loss 1.701584\n",
      "Epoch 48 / Batch (16/17) / Loss 1.675452\n",
      "Epoch 48 / Batch (17/17) / Loss 1.689484\n",
      "Epoch 48 complete, 11.206429958343506 seconds elapsed\n",
      "Epoch 49 starting..\n",
      "Epoch 49 / Batch (1/17) / Loss 1.667831\n",
      "Epoch 49 / Batch (2/17) / Loss 1.680253\n",
      "Epoch 49 / Batch (3/17) / Loss 1.682776\n",
      "Epoch 49 / Batch (4/17) / Loss 1.683904\n",
      "Epoch 49 / Batch (5/17) / Loss 1.680328\n",
      "Epoch 49 / Batch (6/17) / Loss 1.696835\n",
      "Epoch 49 / Batch (7/17) / Loss 1.686903\n",
      "Epoch 49 / Batch (8/17) / Loss 1.701976\n",
      "Epoch 49 / Batch (9/17) / Loss 1.704201\n",
      "Epoch 49 / Batch (10/17) / Loss 1.707786\n",
      "Epoch 49 / Batch (11/17) / Loss 1.700252\n",
      "Epoch 49 / Batch (12/17) / Loss 1.693491\n",
      "Epoch 49 / Batch (13/17) / Loss 1.693668\n",
      "Epoch 49 / Batch (14/17) / Loss 1.686027\n",
      "Epoch 49 / Batch (15/17) / Loss 1.702266\n",
      "Epoch 49 / Batch (16/17) / Loss 1.689639\n",
      "Epoch 49 / Batch (17/17) / Loss 1.690351\n",
      "Epoch 49 complete, 11.418017148971558 seconds elapsed\n",
      "Epoch 50 starting..\n",
      "Epoch 50 / Batch (1/17) / Loss 1.685762\n",
      "Epoch 50 / Batch (2/17) / Loss 1.692014\n",
      "Epoch 50 / Batch (3/17) / Loss 1.691202\n",
      "Epoch 50 / Batch (4/17) / Loss 1.685898\n",
      "Epoch 50 / Batch (5/17) / Loss 1.692247\n",
      "Epoch 50 / Batch (6/17) / Loss 1.682412\n",
      "Epoch 50 / Batch (7/17) / Loss 1.716837\n",
      "Epoch 50 / Batch (8/17) / Loss 1.686781\n",
      "Epoch 50 / Batch (9/17) / Loss 1.691645\n",
      "Epoch 50 / Batch (10/17) / Loss 1.708002\n",
      "Epoch 50 / Batch (11/17) / Loss 1.680395\n",
      "Epoch 50 / Batch (12/17) / Loss 1.665197\n",
      "Epoch 50 / Batch (13/17) / Loss 1.675645\n",
      "Epoch 50 / Batch (14/17) / Loss 1.686304\n",
      "Epoch 50 / Batch (15/17) / Loss 1.681491\n",
      "Epoch 50 / Batch (16/17) / Loss 1.681932\n",
      "Epoch 50 / Batch (17/17) / Loss 1.692051\n",
      "Epoch 50 complete, 11.310213088989258 seconds elapsed\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 50\n",
    "loss_values =[]\n",
    "for epoch in range(n_epochs):\n",
    "    print('Epoch {} starting..'.format(epoch+1))\n",
    "    epoch_start = time.time()\n",
    "    classes_epoch, imgs_epoch = ut.shuffle_xy(classes_train,imgs_train) \n",
    "\n",
    "    n_batches = len(classes_epoch)//(n_classes*batch_size)\n",
    "    for batch in range(n_batches):\n",
    "        classes_batch = classes_epoch[batch*n_classes*batch_size:(batch+1)*n_classes*batch_size]\n",
    "        imgs_batch = imgs_epoch[batch*n_classes*batch_size:(batch+1)*n_classes*batch_size]\n",
    "\n",
    "        Xl_batch, yl_batch = [], []\n",
    "        for episode in range(batch_size):\n",
    "            imgs_ep = imgs_batch[episode*n_classes:(episode+1)*n_classes]\n",
    "            Xl_ep, yl_ep = [], []\n",
    "            for ind, cat in enumerate(imgs_ep):\n",
    "                for arr in cat:\n",
    "                    Xl_ep.append(arr)\n",
    "                    yl_ep.append(ut.one_hot(ind,n_classes))\n",
    "            Xl_shuff, yl_shuff = ut.shuffle_xy(Xl_ep,yl_ep)\n",
    "            X_arr, y_arr = np.asarray(Xl_shuff), np.asarray(yl_shuff)\n",
    "            Xl_batch.append(X_arr)\n",
    "            yl_batch.append(y_arr)\n",
    "        X_train, y_train = np.asarray(Xl_batch), np.asarray(yl_batch)\n",
    "        X_mann, y_mann = get_graph_input(X_train, y_train)\n",
    "        optimizer.zero_grad()\n",
    "        M, h, c, wu, wr, r, o = net(X_mann)\n",
    "        outs = o.view(100*16, 5)\n",
    "        gt = y_mann.view(100*16)\n",
    "        loss = criterion(outs, gt)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        loss_d = loss.item()\n",
    "        _, predicted = torch.max(outs.data, 1)\n",
    "        correct = (predicted == gt).sum().item()\n",
    "        loss_values.append(loss_d)\n",
    "        if(epoch%10==0):\n",
    "            print('Epoch %d'%(epoch+1),'/ Batch (%d/%d)'%(batch+1,n_batches),'/ Loss %f'%(loss_d))\n",
    "\n",
    "    epoch_end = time.time()\n",
    "    time_elapsed = epoch_end-epoch_start\n",
    "    print('Epoch {} complete,'.format(epoch+1),time_elapsed,'seconds elapsed')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#### Visualizations "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8VPW9//HXJ3uAsAdEEEFZXFCjRsRK1boiWtfW6q11qS211da2t/39sL29Lj/bn22ttraK1brV21KvWtyqVYt6qVs1KCIKCChKlCXsSwiQ5HP/mDOTIWQjM8mZOfN+Ph555JzvfM/MZ04m75x858z5mrsjIiLRlRd2ASIi0rUU9CIiEaegFxGJOAW9iEjEKehFRCJOQS8iEnEKehGRiFPQi4hEnIJeRCTiCsIuAGDgwIE+YsSIsMsQEckqs2fPXu3u5e31y4igHzFiBFVVVWGXISKSVczso47009CNiEjEKehFRCJOQS8iEnEZMUYvItGzY8cOqqurqaurC7uUrFdSUsKwYcMoLCzs1PYKehHpEtXV1ZSVlTFixAjMLOxyspa7s2bNGqqrqxk5cmSn7kNDNyLSJerq6hgwYIBCPkVmxoABA1L6z0hBLyJdRiGfHqnux6wO+k/Xb+XmZxfy4eotYZciIpKxsjro127Zzq3PL+b9lZvCLkVEJGNlddD3KY29A71h646QKxGRTLR+/Xpuv/323d5u8uTJrF+/fre3u+SSS3j44Yd3e7uult1B3yMI+loFvYjsqrWgb2hoaHO7p556ir59+3ZVWd0uq0+v7FVUQJ7piF4k0133xLu89+nGtN7nAXv25prPH9hmn6lTp7JkyRIqKiooLCykV69eDBkyhDlz5vDee+9x1llnsWzZMurq6rjqqquYMmUK0HT9rc2bN3PqqacyceJEXnnlFYYOHcpjjz1GaWlpu/XNnDmTH/zgB9TX13PEEUcwbdo0iouLmTp1Ko8//jgFBQWcfPLJ3HTTTTz00ENcd9115Ofn06dPH2bNmpWWfRSX1UGfl2f0KS1U0ItIi2688UbmzZvHnDlzePHFFznttNOYN29e4nz0e+65h/79+7N161aOOOIIzj33XAYMGLDTfSxatIjp06dz1113cd555/HII49w4YUXtvm4dXV1XHLJJcycOZMxY8Zw0UUXMW3aNC666CJmzJjBggULMLPE8ND111/PM888w9ChQzs1ZNSedoPezO4BTgdWufu4oO1BYGzQpS+w3t0rzGwEMB9YGNz2mrtfnu6ikw3uXcIrS1bj7jqVSyRDtXfk3V3Gjx+/04eObr31VmbMmAHAsmXLWLRo0S5BP3LkSCoqKgA4/PDDWbp0abuPs3DhQkaOHMmYMWMAuPjii7ntttu48sorKSkp4Wtf+xqnnXYap59+OgBHH300l1xyCeeddx7nnHNOOp7qTjoyRn8fMCm5wd2/5O4V7l4BPAL8NenmJfHbujrkASaN24MlNVvYuqPtMTcRkZ49eyaWX3zxRf7xj3/w6quv8vbbb3PooYe2+KGk4uLixHJ+fj719fXtPo67t9heUFDA66+/zrnnnsujjz7KpEmxaL3jjju44YYbWLZsGRUVFaxZs2Z3n1qb2j2id/dZwZH6Lix2CH0ecHxaq9oN5WWxH8Kmunp6FGX1SJSIpFlZWRmbNrV8+vWGDRvo168fPXr0YMGCBbz22mtpe9z99tuPpUuXsnjxYkaNGsUDDzzAsccey+bNm6mtrWXy5MlMmDCBUaNGAbBkyRKOPPJIjjzySJ544gmWLVu2y38WqUg1GT8LrHT3RUltI83sLWAj8B/u/s8UH6NNvYpjT2FTXT2De3flI4lIthkwYABHH30048aNo7S0lMGDBydumzRpEnfccQcHH3wwY8eOZcKECWl73JKSEu69916++MUvJt6Mvfzyy1m7di1nnnkmdXV1uDu33HILAD/84Q9ZtGgR7s4JJ5zAIYcckrZaAKy1fzF26hQ7on8yPkaf1D4NWOzuvwrWi4Fe7r7GzA4HHgUOdPdd3m43synAFIDhw4cf/tFHHZooZRcvLFjFpfe9wYxvfYZDh/fr1H2ISPrNnz+f/fffP+wyIqOl/Wlms929sr1tO30evZkVAOcAD8bb3H2bu68JlmcDS4AxLW3v7ne6e6W7V5aXtzvlYavKSpqO6EVEZFepDN2cCCxw9+p4g5mVA2vdvcHM9gFGAx+kWGObykpiH5pS0ItId7niiit4+eWXd2q76qqruPTSS0OqqG0dOb1yOnAcMNDMqoFr3P1u4HxgerPuxwDXm1k90ABc7u5r01vyznoW5wOwZZuCXiTTRPW059tuu61bH68jQ+xt6chZNxe00n5JC22PEDvdstvEz7TR6ZUimaWkpIQ1a9bomvQpik88UlJS0un7yPrzEUsLY0f0CnqRzDJs2DCqq6upqakJu5SsF59KsLOyPuiLC2LvJ2/drqAXySSFhYWdnvpO0iurr14JsevdlBTmUacjehGRFmV90ENs+KZWR/QiIi2KTNBrjF5EpGWRCPqSIgW9iEhrIhH0pYX51GnoRkSkRZEI+qF9S5m5YBULV2iScBGR5iIR9BeMHw7A8wtWhVyJiEjmiUTQHzsmdlE0jdOLiOwqEkGfl2excXoFvYjILiIR9AClRfnUbteFzUREmotO0Bfms3V7Y9hliIhknOgEfZGGbkREWhKdoC/U0I2ISEsiFfQ660ZEZFfRCfqifNbX7mBD7Y6wSxERySiRCfo9+5ayYMUmjrpxZtiliIhklMgE/TGjBwLocsUiIs20G/Rmdo+ZrTKzeUlt15rZJ2Y2J/ianHTb1Wa22MwWmtkpXVV4c5PG7cHkg/ZIzDglIiIxHUnF+4BJLbTf4u4VwddTAGZ2AHA+cGCwze1mlp+uYttiZowZXMa2+kYaGlObMV1EJEraDXp3nwWs7eD9nQn8xd23ufuHwGJgfAr17ZYeRZooXESkuVTGOa40s7nB0E6/oG0osCypT3XQtgszm2JmVWZWla5Z4nsUxeY61/n0IiJNOhv004B9gQpgOfCroN1a6NviOIq73+nule5eWV5e3skydhY/oq/dpiN6EZG4TgW9u6909wZ3bwTuoml4phrYK6nrMODT1ErsuETQ68wbEZGETgW9mQ1JWj0biJ+R8zhwvpkVm9lIYDTwemoldlxpMHSzdYeGbkRE4gra62Bm04HjgIFmVg1cAxxnZhXEhmWWAt8AcPd3zey/gfeAeuAKd++2w+uS4NTKbTt0FUsRkbh2g97dL2ih+e42+v8U+GkqRXVWUTzoGxT0IiJxkfp0UZGO6EVEdhGpoC8uiL0Zu11H9CIiCREL+tjT2V6voBcRiYtU0CeGbup1eqWISFykgl5H9CIiu4pU0Dcd0SvoRUTiohX0+TqiFxFpLlJBXxAE/c3PvR9yJSIimSNSQZ9MV7AUEYmJbNAvXV0bdgkiIhkhckE/tG8pAB+v3RJyJSIimSFyQf/wN48CYO2WHSFXIiKSGSIX9P16FAGwrnZ7yJWIiGSGyAV9SWE+PYryWbtFQS8iAhEMeojNMHX3Sx9Sr4ubiYhEM+jHDi4DYNWmbSFXIiISvkgG/dRT9wNgxca6kCsREQlfJIN+cO8SAM65/RXWaaxeRHJcu0FvZveY2Sozm5fU9kszW2Bmc81shpn1DdpHmNlWM5sTfN3RlcW3Zo8+JYnlj9bqg1Mikts6ckR/HzCpWdtzwDh3Pxh4H7g66bYl7l4RfF2enjJ3T78ehYnlzXW6FIKI5LZ2g97dZwFrm7U96+7xBH0NGNYFtXWamSWWV2/WG7IiktvSMUb/VeDppPWRZvaWmf2PmX22tY3MbIqZVZlZVU1NTRrKaNmqTXpDVkRyW0pBb2Y/BuqBPwVNy4Hh7n4o8H3gz2bWu6Vt3f1Od69098ry8vJUymjRu9edQo+ifBat3Jz2+xYRySadDnozuxg4HfiyuzuAu29z9zXB8mxgCTAmHYXurp7FBRwxoj+vLFmjOWRFJKd1KujNbBLwf4Ez3L02qb3czPKD5X2A0cAH6Si0My4YvxefrN/KK4vXhFWCiEjoOnJ65XTgVWCsmVWb2WXA74Ay4Llmp1EeA8w1s7eBh4HL3X1ti3fcDY4bO4jCfOONpaGVICISuoL2Orj7BS00391K30eAR1ItKl1KCvMZVFbC8g16Q1ZEclckPxmbrLysmBlvfcKY/3i6/c4iIhEU+aAfVFYMwPb6RhoaPeRqRES6X/SDvndxYlkThotILop+0Jc1XfemdrtOsxSR3BP5oC8tzE8sK+hFJBdFPuiTbdmmoRsRyT2RD/oLjhzOiAE9AB3Ri0huinzQ9you4OYvVQB6M1ZEclPkgx5iYQ+wSdemF5EclBNBP7BX7BTLb09/Sxc4E5GckxNBnzzj1MoNmohERHJLTgR98oxTKzbqujcikltyIugBTjtoCADLN2wNuRIRke6VM0H/8y8cDMAKXclSRHJMzgR9r+ICehUXsHTNlrBLERHpVjkT9ACbt9Uz/fVlrNuyPexSRES6TU4FfdynGqcXkRySU0H/p68dCcBpt77EghUbQ65GRKR7dCjozeweM1tlZvOS2vqb2XNmtij43i9oNzO71cwWm9lcMzusq4rfXf16FCWWL39gdoiViIh0n44e0d8HTGrWNhWY6e6jgZnBOsCpwOjgawowLfUy06Nv0genVm3SB6dEJDd0KOjdfRawtlnzmcD9wfL9wFlJ7X/0mNeAvmY2JB3FpqpPaVPQ125vYEdDY4jViIh0j1TG6Ae7+3KA4PugoH0osCypX3XQFrqexQX8v7PGMeWYfQD4ZN1WturSxSIScV3xZqy10LbLrNxmNsXMqsysqqampgvKaNlXJuzNIcP6AnDcTS9yzC9f6LbHFhEJQypBvzI+JBN8XxW0VwN7JfUbBnzafGN3v9PdK929sry8PIUydl/P4qbpBWs0Vi8iEZdK0D8OXBwsXww8ltR+UXD2zQRgQ3yIJ1NMHDWQ4f17hF2GiEi36OjpldOBV4GxZlZtZpcBNwInmdki4KRgHeAp4ANgMXAX8K20V52igvw8vnvi6MS63pQVkSgr6Egnd7+glZtOaKGvA1ekUlR3KC5oGr5ZvXkbQ/qUhliNiEjXyalPxiY7fO9+ieXf/88HxP4+iYhET84G/R59Snjy2xMBuO+Vpbz7qS6JICLRlLNBDzCkT0li+f2Vm0KsRESk6+R00A/oVcyfvx670NmHq3WdehGJppwOeoDP7DsQgN8+v5jVm3VOvYhET84HfbL/eu2jsEsQEUk7BX2S5es1n6yIRI+CPsmz760IuwQRkbRT0ANXn7ofAOtqd+hTsiISOQp64BvH7sv1Zx4IwMqNdWyr16WLRSQ6FPSBvsE0gxN//gKn3DIr5GpERNJHQR/olzTN4NI1tSFWIiKSXgr6wJjBZTut69o3IhIVCvrA4N4llBQ27Y71tTtCrEZEJH0U9Ekmj2uaw3zNlu0hViIikj4K+iQ/O+cgfnb2QQCc8buX2LKtPuSKRERSp6BPUlKYz8HD+gBQu72B+ct16WIRyX4K+mb69yxKLK/erOEbEcl+nQ56MxtrZnOSvjaa2XfN7Foz+ySpfXI6C+5qg8qKOWBIbwBqdDVLEYmATge9uy909wp3rwAOB2qBGcHNt8Rvc/en0lFodynIz+OJb08kz6Bmoy5yJiLZL11DNycAS9w9Etf5zc8z+vcs5oHXPmLhCs08JSLZLV1Bfz4wPWn9SjOba2b3mFm/1jbKZBu2bmdd7Q5O+fUsGhv14SkRyV4pB72ZFQFnAA8FTdOAfYEKYDnwq1a2m2JmVWZWVVNTk2oZabejoSncv//fc0KsREQkNek4oj8VeNPdVwK4+0p3b3D3RuAuYHxLG7n7ne5e6e6V5eXlaSgjvSYftEdi+dE5n9Kgo3oRyVLpCPoLSBq2MbMhSbedDcxLw2N0u99ecBgLb5jEL849GIBla3WhMxHJTgWpbGxmPYCTgG8kNf/CzCoAB5Y2uy1r5OcZ+Xn5jNkjdrGz91duYsTAniFXJSKy+1IKenevBQY0a/tKShVlmNGDegHw+odrOfnAPdrpLSKSefTJ2Hb0LI79LfzDSx8y75MNIVcjIrL7FPQd8NWjRwJw+m9f4hd/XxByNSIiu0dB3wH/+fkD6FmUD8DtLy4JuRoRkd2joO+gwgLtKhHJTkqvDirM164Skeyk9Oqg0sL8xPJbH69j8SpdA0dEsoOCvoOG9ClJLJ99+yucePOsEKsREek4BX0HXfG5UWGXICLSKQr6DjpmTDl3XHh42GWIiOw2Bf1uOOmAwWGXICKy2xT0uyE/z3jiyonAzm/OiohkMgX9bjpoWB++ddy+bN3RwO+eX6RJSUQk4ynoO6GspBCAm559n/d1mqWIZDgFfSd8dvTAxPJ3/6LZp0QksynoO2Hc0D788guxCUkWaPJwEclwCvpOOnZM0/SHmmZQRDKZgr6T+vcsSiyvq90eYiUiIm1T0HdSQX4ev/u3QwF4+p3lOqoXkYyloE/BcWMHAfCTx97l55qQREQyVMpBb2ZLzewdM5tjZlVBW38ze87MFgXf+6VeaubpVdw05e7T85aHWImISOvSdUT/OXevcPfKYH0qMNPdRwMzg/VIW7Z2KzPnrwy7DBGRXXTV0M2ZwP3B8v3AWV30OKF78tsTE8uX3V8VYiUiIi1LR9A78KyZzTazKUHbYHdfDhB8H9R8IzObYmZVZlZVU1OThjLCMW5oHy4YPzyxvqluR4jViIjsKh1Bf7S7HwacClxhZsd0ZCN3v9PdK929sry8vP0NMthlE0cwcVTs07Jzlq0PuRoRkZ2lHPTu/mnwfRUwAxgPrDSzIQDB91WpPk4mGzWojGkXHoYZzP5oXdjliIjsJKWgN7OeZlYWXwZOBuYBjwMXB90uBh5L5XGyQVlJIWMHl/HMuyvZ0dAYdjkiIgmpHtEPBl4ys7eB14G/ufvfgRuBk8xsEXBSsB55X//sPsxfvpFrH3837FJERBIK2u/SOnf/ADikhfY1wAmp3Hc2OvfwYTzyZjV/fv1jvnncvgzr1yPskkRE9MnYdJs0bg/cYeLPXwi7FBERQEGfdqMG9Uosz/5obYiViIjEKOjTrHLv/onlc6e9yprN20KsRkREQZ92RQV5XHfGgYn1w2/4h87CEZFQKei7wOcP2XOn9Y/W1IZUiYiIgr5L9O9ZxH2XHpFYr16noBeR8Cjou0hRQdOuveTeN3jtgzUhViMiuUxB30X69Sjaaf2y+95gQ60ueCYi3U9B30X2H9Kb6V+fwPs3nMpT3/ksW7Y38KfXPwq7LBHJQQr6LnTUvgMoKsjjgD17c9jwvvx93oqwSxKRHKSg7ybH7zeIudUbGDH1b0z42UyqlurDVCLSPRT03SQ+kTjAio11/NdrGsYRke6hoO8mB+7ZO7E8ZnAvHp3zKX/45wchViQiuUJB303MjGe/dwwvTz2ezx8c+0DVDX+bH3JVIpILFPTdaMzgMob2LWXKsfswuHcxAPv95GmenPtpyJWJSJQp6ENQXJDPjybvD0Ddjkau/PNbuHvIVYlIVCnoQ/L5g/fkOyeMTqzfOesD6nXxMxHpAgr6kOTlGd8/aQy/Ob8CgP//9AKmvbgk5KpEJIo6HfRmtpeZvWBm883sXTO7Kmi/1sw+MbM5wdfk9JUbPWdWDOWSz4wA4KXFq8MtRkQiKZU5Y+uBf3f3N82sDJhtZs8Ft93i7jelXl5uuPaMA/lg9RZmvV/DzPkrOWH/wWGXJCIR0ukjendf7u5vBsubgPnA0HQVlmsGl8XOwrns/iqu/PObugCaiKRNWsbozWwEcCjwr6DpSjOba2b3mFm/VraZYmZVZlZVU1OTjjKy2jeP2zex/OTc5Rxy/bM0NOpMHBFJXcpBb2a9gEeA77r7RmAasC9QASwHftXSdu5+p7tXuntleXl5qmVkvX3Ke/HBzyYz+aA9Em0PvrEsxIpEJCpSCnozKyQW8n9y978CuPtKd29w90bgLmB86mXmhrw84/YvH874kbEJxn804x2ee29lyFWJSLZL5awbA+4G5rv7zUntQ5K6nQ3M63x5uemPXx1Pn9JCAO59+UNNLi4iKUnliP5o4CvA8c1OpfyFmb1jZnOBzwHfS0ehuaSkMJ+3rzmZi47am1eWrOHA/3yGxas2hV2WiGQpy4SP3ldWVnpVVVXYZWSc1Zu3UXnDPxLrC2+YRHFBfogViUgmMbPZ7l7ZXj99MjaDDexVzMOXH5VYv+0FfXJWRHafgj7DVY7oz/zrJ7FX/1JunbmIEVP/xvML9AatiHScgj4LlBbl8/erjklMXvLV+zTMJSIdp6DPEj2LC7jv0qYzVRet1JuzItIxCvosUl5WzH2XHgHASbfM4vUPNcG4iLRPQZ9ljh1TzsHD+gBw3u9f5et/rNKkJSLSJgV9ljEz/vjV8dzypUMAeO69lbxdvSHkqkQkkynos1DfHkWcfegwXr36eADOuu1lfaBKRFqloM9iQ/qU8usvVdC7pIATb57FwhUKexHZlYI+y5116FC+9blRAJx9+8shVyMimUhBHwEXTtgbgNrtDRx0zTPMWbY+5IpEJJMo6COgV3EBb/7kJArzjU3b6rno7n+xevO2sMsSkQyhoI+I/j2LeOfaU/jp2ePYVt/Iv931GnOrdWQvIgr6SCkpzOfLR+7NdWccyPsrN3PO7a8wc76uiyOS6xT0EXT++OE8dsXRjBzYk8vur+Knf3sv7JJEJEQK+og6ZK++PPiNoyguyOOuf37I1X99RzNVieQoBX2E9e9ZxNvXnMwpBw5m+usfM/rHT7N8w9awyxKRbqYZpnJA7fZ6jvnFC6zevB2A4f17cP74vThyZH8G9CxmcO8SSos0c5VItunoDFNdFvRmNgn4DZAP/MHdb2ytr4K+69U3NDLjrU94aHZ1i1e9PHH/QRw7ppza7Q2UlxVTWphPn9JC+vQoJM+MPDPMIM9i19uJtZHUHrTlQVF+Hvl5hplhgBkYsX5xzdti/Zr6E6yLSOtCDXozywfeB04CqoE3gAvcvcV3BRX03atuRwMPvrGMT9ZvZcZbn9C7pICt2xv4dENd2KW1KPZHIb688x+PWCO7tCX/AcnLM/LzYg2NjY4DxQV5uEN9oyf6x/94AbT2axG7X2tW0679k1dbuq/m2zf/o9ecB/fjtP37akm1NXawf/yPd6t9Wnh+LfVJ/g5N2yRvm7yf449Z39jY5v239fe++c871hY8LtDoHttvze4//pyS8y95HyQ/n+THiPdo9Nh9Azsd9DQEbdasT/x+8qzpZxmv4fj9BnHtGQe2/iTb0NGgL+jUvbdvPLDY3T8IivkLcCag0z8yQElhPhd/ZgQAP5q8PwANjc4n67ZSUpTHsrW11Dc4W3c0sK2+EXdPvGg96XtDo8eWIdGnvtHZXt8YBGrwS0bSL33QBk2/ZMl94sEUb8Oboip+e+K25G12amv65Y3X3dDoif8aALY3NJJnUJCXl9im0aHBfacATrbzc2mqk1b6J0f2rveV9DybPe+d+iXdS0v/GTUXv69GJxE+7fWN/Uxb6ZP0hyL5sZPrTPRJ+mZNGyW2Tf45NQb7G6Agr43n1MYfgJZ+3smvi/gfk+Sfe6JeD24LwnuX11NwZ8mPkdwnP892+YMS3+et9fHmfYLds095z9afZJp0VdAPBZYlrVcDR3bRY0ka5OcZwwf0AGBQWUnI1YhIOnXVWTet/ffZ1MFsiplVmVlVTU1NF5UhIiJdFfTVwF5J68OAT5M7uPud7l7p7pXl5eVdVIaIiHRV0L8BjDazkWZWBJwPPN5FjyUiIm3okjF6d683syuBZ4idXnmPu7/bFY8lIiJt66o3Y3H3p4Cnuur+RUSkY3QJBBGRiFPQi4hEnIJeRCTiMuKiZmZWA3yUwl0MBFanqZyo0b5pm/ZP27R/WpcJ+2Zvd2/3/PSMCPpUmVlVR673kIu0b9qm/dM27Z/WZdO+0dCNiEjEKehFRCIuKkF/Z9gFZDDtm7Zp/7RN+6d1WbNvIjFGLyIirYvKEb2IiLQiq4PezCaZ2UIzW2xmU8OuJwxmtpeZvWBm883sXTO7Kmjvb2bPmdmi4Hu/oN3M7NZgn801s8PCfQZdz8zyzewtM3syWB9pZv8K9s2DwYX3MLPiYH1xcPuIMOvuDmbW18weNrMFwWvoKL12mpjZ94Lfq3lmNt3MSrLx9ZO1QR9MV3gbcCpwAHCBmR0QblWhqAf+3d33ByYAVwT7YSow091HAzODdYjtr9HB1xRgWveX3O2uAuYnrf8cuCXYN+uAy4L2y4B17j4KuCXoF3W/Af7u7vsBhxDbT3rtAGY2FPgOUOnu44hdoPF8svH14+5Z+QUcBTyTtH41cHXYdYX9BTxGbK7ehcCQoG0IsDBY/j2x+Xvj/RP9ovhFbC6EmcDxwJPEJsVZDRQ0fx0Ru9rqUcFyQdDPwn4OXbhvegMfNn+Oeu0knl98prz+wevhSeCUbHz9ZO0RPS1PVzg0pFoyQvCv4qHAv4DB7r4cIPg+KOiWa/vt18D/ARqD9QHAenevD9aTn39i3wS3bwj6R9U+QA1wbzC09Qcz64leOwC4+yfATcDHwHJir4fZZOHrJ5uDvt3pCnOJmfUCHgG+6+4b2+raQlsk95uZnQ6scvfZyc0tdPUO3BZFBcBhwDR3PxTYQtMwTUtyav8E702cCYwE9gR6Ehu+ai7jXz/ZHPTtTleYK8yskFjI/8nd/xo0rzSzIcHtQ4BVQXsu7bejgTPMbCnwF2LDN78G+ppZfC6G5Oef2DfB7X2Atd1ZcDerBqrd/V/B+sPEgl+vnZgTgQ/dvcbddwB/BT5DFr5+sjnoNV0hsTMhgLuB+e5+c9JNjwMXB8sXExu7j7dfFJxBMQHYEP83PWrc/Wp3H+buI4i9Pp539y8DLwBfCLo13zfxffaFoH9GHJF1BXdfASwzs7FB0wnAe+i1E/cxMMHMegS/Z/H9k32vn7DfJEjxzZLJwPvAEuDHYdcT0j6YSOzfw7nAnOBrMrGxwZnAouB7/6C/ETtbaQnwDrEzCkJ/Ht2wn44DngzrsBOzAAAAbElEQVSW9wFeBxYDDwHFQXtJsL44uH2fsOvuhv1SAVQFr59HgX567ey0f64DFgDzgAeA4mx8/eiTsSIiEZfNQzciItIBCnoRkYhT0IuIRJyCXkQk4hT0IiIRp6AXEYk4Bb2ISMQp6EVEIu5/Aauo+vxAqZJ9AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def plot_loss(train,name1=\"train_loss\"):\n",
    "    plt.plot(train, label=name1)\n",
    "    plt.legend()\n",
    "\n",
    "plot_loss(loss_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
